Repository: LeelaChessZero/lczero-training Branch: master Commit: a60c7d281ebc Files: 219 Total size: 1.1 MB Directory structure: gitextract_k9ephqoc/ ├── .clang-format ├── .gitignore ├── .gitmodules ├── AGENTS.md ├── README.md ├── csrc/ │ ├── loader/ │ │ ├── chunk_source/ │ │ │ ├── chunk_source.h │ │ │ ├── chunk_source_view.h │ │ │ ├── debug_chunk_source.cc │ │ │ ├── debug_chunk_source.h │ │ │ ├── rawfile_chunk_source.cc │ │ │ ├── rawfile_chunk_source.h │ │ │ ├── tar_chunk_source.cc │ │ │ └── tar_chunk_source.h │ │ ├── data_loader.cc │ │ ├── data_loader.h │ │ ├── data_loader_metrics.cc │ │ ├── data_loader_metrics.h │ │ ├── data_loader_test.cc │ │ ├── frame_type.h │ │ ├── loader_main.cpp │ │ ├── pybind_module.cc │ │ └── stages/ │ │ ├── chunk_rescorer.cc │ │ ├── chunk_rescorer.h │ │ ├── chunk_rescorer_test.cc │ │ ├── chunk_source_loader.cc │ │ ├── chunk_source_loader.h │ │ ├── chunk_source_loader_test.cc │ │ ├── chunk_source_splitter.cc │ │ ├── chunk_source_splitter.h │ │ ├── chunk_source_splitter_test.cc │ │ ├── chunk_unpacker.cc │ │ ├── chunk_unpacker.h │ │ ├── chunk_unpacker_test.cc │ │ ├── file_path_provider.cc │ │ ├── file_path_provider.h │ │ ├── file_path_provider_main.cc │ │ ├── file_path_provider_test.cc │ │ ├── join_stage.cc │ │ ├── join_stage.h │ │ ├── join_stage_test.cc │ │ ├── position_sampling.cc │ │ ├── position_sampling.h │ │ ├── shuffling_chunk_pool.cc │ │ ├── shuffling_chunk_pool.h │ │ ├── shuffling_chunk_pool_test.cc │ │ ├── shuffling_frame_sampler.cc │ │ ├── shuffling_frame_sampler.h │ │ ├── shuffling_frame_sampler_test.cc │ │ ├── simple_chunk_extractor.cc │ │ ├── simple_chunk_extractor.h │ │ ├── simple_chunk_extractor_test.cc │ │ ├── stage.cc │ │ ├── stage.h │ │ ├── stage_factory.cc │ │ ├── stage_factory.h │ │ ├── stage_factory_test.cc │ │ ├── tensor_generator.cc │ │ ├── tensor_generator.h │ │ ├── tensor_generator_test.cc │ │ └── training_chunk.h │ ├── tools/ │ │ ├── dump_chunk_main.cc │ │ ├── filter_chunks_main.cc │ │ ├── 
position_weight_stats_main.cc │ │ ├── rescore_chunk_main.cc │ │ ├── result_distribution_main.cc │ │ └── startpos_policy_distribution_main.cc │ └── utils/ │ ├── gz.cc │ ├── gz.h │ ├── metrics/ │ │ ├── exponential_aggregator.h │ │ ├── group.h │ │ ├── load_metric.h │ │ ├── load_metric_test.cc │ │ ├── printer.h │ │ ├── statistics_metric.h │ │ └── stats_test.cc │ ├── queue.h │ ├── queue_test.cc │ ├── stream_shuffler.cc │ ├── stream_shuffler.h │ ├── stream_shuffler_test.cc │ ├── tensor.h │ ├── tensor_test.cc │ ├── thread_pool.h │ ├── training_data_printer.cc │ └── training_data_printer.h ├── docs/ │ ├── README.md │ ├── architecture.md │ ├── checkpoint_migration.md │ ├── example.textproto │ ├── heads.md │ ├── index.md │ ├── loader.md │ ├── new_stage.md │ ├── overview.md │ ├── shuffling_pool_hanse_sampling.md │ ├── training_tuple.md │ ├── tui.md │ └── weights_tool.md ├── init.sh ├── justfile ├── meson.build ├── native.ini ├── proto/ │ ├── checkpoint_migration_config.proto │ ├── data_loader_config.proto │ ├── export_config.proto │ ├── metrics_config.proto │ ├── model_config.proto │ ├── root_config.proto │ ├── stage_control.proto │ ├── training_config.proto │ └── training_metrics.proto ├── pyproject.toml ├── scripts/ │ ├── diff.py │ ├── fixorder.py │ ├── init.sh │ ├── initsplit.py │ ├── inittrainingname.py │ ├── pack.py │ ├── purge.py │ ├── rescore.sh │ ├── shuffle.py │ ├── split.sh │ ├── stage.sh │ ├── unpack.py │ └── upload.sh ├── src/ │ ├── lczero_training/ │ │ ├── __init__.py │ │ ├── _lczero_training.pyi │ │ ├── commands/ │ │ │ ├── __init__.py │ │ │ ├── backfill_metrics.py │ │ │ ├── common.py │ │ │ ├── daemon.py │ │ │ ├── dataloader_viz.py │ │ │ ├── describe_training.py │ │ │ ├── jax2leela.py │ │ │ ├── leela2jax.py │ │ │ ├── migrate_checkpoint.py │ │ │ ├── overfit.py │ │ │ ├── test_dataloader.py │ │ │ ├── train.py │ │ │ ├── training_eval.py │ │ │ ├── training_init.py │ │ │ ├── tui.py │ │ │ ├── tune_lr.py │ │ │ └── weights_tool.py │ │ ├── convert/ │ │ │ ├── __init__.py │ 
│ │ ├── jax_to_leela.py │ │ │ ├── leela_pytree_visitor.py │ │ │ ├── leela_to_jax.py │ │ │ └── leela_to_modelconfig.py │ │ ├── daemon/ │ │ │ ├── __init__.py │ │ │ ├── daemon.py │ │ │ ├── metrics.py │ │ │ ├── metrics_base.py │ │ │ ├── pipeline.py │ │ │ ├── protocol/ │ │ │ │ ├── __init__.py │ │ │ │ ├── communicator.py │ │ │ │ ├── messages.py │ │ │ │ └── registry.py │ │ │ └── rms_metrics.py │ │ ├── dataloader/ │ │ │ └── __init__.py │ │ ├── model/ │ │ │ ├── __init__.py │ │ │ ├── embedding.py │ │ │ ├── encoder.py │ │ │ ├── loss_function.py │ │ │ ├── model.py │ │ │ ├── movesleft_head.py │ │ │ ├── policy_head.py │ │ │ ├── shared.py │ │ │ ├── utils.py │ │ │ └── value_head.py │ │ ├── py.typed │ │ ├── tests/ │ │ │ ├── test_protobuf.py │ │ │ ├── test_protocol_registry.py │ │ │ └── test_weights_tool.py │ │ ├── tools/ │ │ │ ├── __init__.py │ │ │ ├── weight_codecs.py │ │ │ ├── weight_wrappers.py │ │ │ └── weights_tool.py │ │ ├── training/ │ │ │ ├── __init__.py │ │ │ ├── backfill_metrics.py │ │ │ ├── dataloader_probe.py │ │ │ ├── describe.py │ │ │ ├── eval.py │ │ │ ├── init.py │ │ │ ├── lr_schedule.py │ │ │ ├── migrate_checkpoint.py │ │ │ ├── optimizer.py │ │ │ ├── overfit.py │ │ │ ├── state.py │ │ │ ├── tensorboard.py │ │ │ ├── test_lr_schedule.py │ │ │ ├── training.py │ │ │ ├── tune_lr.py │ │ │ └── utils.py │ │ └── tui/ │ │ ├── __init__.py │ │ ├── app.py │ │ ├── app.tcss │ │ ├── data_pipeline_pane.py │ │ ├── dataloader_widgets.py │ │ ├── log_pane.py │ │ └── training_widgets.py │ └── proto/ │ └── __init__.py └── tf/ ├── attention_policy_map.py ├── chunkparsefunc.py ├── chunkparser.py ├── configs/ │ └── example.yaml ├── decode_training.py ├── lc0_az_policy_map.py ├── make_model.py ├── model_to_net.py ├── net.py ├── net_to_model.py ├── policy_index.py ├── requirements.txt ├── shufflebuffer.py ├── start.sh ├── tfprocess.py ├── train.py └── update_steps.py ================================================ FILE CONTENTS ================================================ 
================================================ FILE: .clang-format ================================================ BasedOnStyle: Google ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class *_pb2.py # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ # protobuf stuff tf/proto/ **/*_pb2.py **/*_pb2.pyi # Meson builddir/ subprojects/ # LLM Agents # Use AGENTS.md; symlink other files to it if needed. 
.claude/ CLAUDE.local.md .claude.json CLAUDE.md GEMINI.md # IDE .vscode/ .idea/ ================================================ FILE: .gitmodules ================================================ [submodule "libs/lc0"] path = libs/lc0 url = https://github.com/LeelaChessZero/lc0.git ================================================ FILE: AGENTS.md ================================================ # AGENTS.md This repository contains training script for the Leela Chess Zero project. They are being rewritten. * Old code is located in the `tf/` directory. * New python code is located in the `src/` directory. * New C++ code is located in the `csrc/` directory. The old code is Python/TensorFlow-based, new code is Python/JAX-based with modules written in C++, operating through pybind11. The build system for C++ code is meson. During development, the project is built in the `builddir/`. ## Testing and Building * C++ tests use GTest framework * Do not insert Sleeps in tests, it slows down presubmit. Instead use e.g. absl::Notification, or std::future * Tests are defined in `meson.build` with `test()` function * When debugging, don't forget to build them before running `meson test` or `builddir/test` * Run tests: `meson test -C builddir/` * Python tests use `pytest` framework * Do not add custom main function, exception catching to report errors, any "test passed" messages etc. Use `pytest` fixtures and assertions. * Build: `meson compile -C builddir/` from build directory * Format code: `just format` * There is a commit hook that runs `just pre-commit`, which runs tests and checks formatting. You may want to run it before attempting to commit. * We use Google C++ style guide. * That means 80 columns. * That means comments should be in full sentences with periods in the end. * When conditional or loop fits one line, it must be written as one line without braces, for example: `if (condition) return value;` * Prefer `absl` to `std` (e.g. 
`absl::c_` algorithms, `absl::Mutex`, `absl::StrCat`, etc.) * We use `uv` for Python package and venv management, and to running the application. * Run TUI app: `uv run tui --config=` * Do not attempt to run TUI — it messes up the Agent interface and session has to be killed. Ask me to check it for you manually instead. * Do not commit unless explicitly asked. ## IMPORTANT * NEVER add `# type: ignore` or other ways to mask/silence errors instead of fixing them. * Rely on protobuf default values. DO NOT write code like `config.has_foo() ? config.foo() : default_value;` ## Documentation * Documentation is in the `docs/` directory. * The contents is in [The index](docs/index.md) ================================================ FILE: README.md ================================================ # Training The training pipeline resides in `tf`, this requires tensorflow running on linux (Ubuntu 16.04 in this case). (It can be made to work on windows too, but it takes more effort.) ## Installation Install the requirements under `tf/requirements.txt`. And call `./init.sh` to compile the protobuf files. ## Data preparation In order to start a training session you first need to download training data from https://storage.lczero.org/files/training_data/. Several chunks/games are packed into a tar file, and each tar file contains an hour worth of chunks. Preparing data requires the following steps: ``` wget https://storage.lczero.org/files/training_data/training-run1--20200711-2017.tar tar -xzf training-run1--20200711-2017.tar ``` ## Training pipeline Now that the data is in the right format one can configure a training pipeline. This configuration is achieved through a yaml file, see `training/tf/configs/example.yaml`: ```yaml %YAML 1.2 --- name: 'kb1-64x6' # ideally no spaces gpu: 0 # gpu id to process on dataset: num_chunks: 100000 # newest nof chunks to parse train_ratio: 0.90 # trainingset ratio # For separated test and train data. 
input_train: '/path/to/chunks/*/draw/' # supports glob input_test: '/path/to/chunks/*/draw/' # supports glob # For a one-shot run with all data in one directory. # input: '/path/to/chunks/*/draw/' training: batch_size: 2048 # training batch total_steps: 140000 # terminate after these steps test_steps: 2000 # eval test set values after this many steps # checkpoint_steps: 10000 # optional frequency for checkpointing before finish shuffle_size: 524288 # size of the shuffle buffer lr_values: # list of learning rates - 0.02 - 0.002 - 0.0005 lr_boundaries: # list of boundaries - 100000 - 130000 policy_loss_weight: 1.0 # weight of policy loss value_loss_weight: 1.0 # weight of value loss path: '/path/to/store/networks' # network storage dir model: filters: 64 residual_blocks: 6 ... ``` The configuration is pretty self explanatory, if you're new to training I suggest looking at the [machine learning glossary](https://developers.google.com/machine-learning/glossary/) by google. Now you can invoke training with the following command: ```bash ./train.py --cfg configs/example.yaml --output /tmp/mymodel.txt ``` This will initialize the pipeline and start training a new neural network. You can view progress by invoking tensorboard: ```bash tensorboard --logdir leelalogs ``` If you now point your browser at localhost:6006 you'll see the trainingprogress as the trainingsteps pass by. Have fun! ## Restoring models The training pipeline will automatically restore from a previous model if it exists in your `training:path` as configured by your yaml config. For initializing from a raw `weights.txt` file you can use `training/tf/net_to_model.py`, this will create a checkpoint for you. ## Supervised training Generating trainingdata from pgn files is currently broken and has low priority, feel free to create a PR. ## Building 2025-08 version. 1. Make sure `uv` and `justfile` are installed (plus `meson` and other stuff, potentially `protoc`). 2. `git submodule update` 3. 
`uv venv` (!important! do this before running meson; otherwise meson will build module for wrong python) 4. `uv sync` 5. `CXX=clang++ CC=clang uv run meson setup build/release/ --buildtype=release --native-file=native.ini` (clang is optional, should build fine with default compiler) 6. `just build-proto` 7. `meson compile -C build/release/` 8. `ln -s -T ../../build/release/_lczero_training.cpython-311-x86_64-linux-gnu.so src/lczero_training/_lczero_training.so` 14. Run it! `uv run tui --config docs/example.textproto` ================================================ FILE: csrc/loader/chunk_source/chunk_source.h ================================================ #pragma once #include #include #include #include #include "loader/frame_type.h" namespace lczero { namespace training { // Interface for providing training data chunks. // A chunk source provides access to one or more chunks of training data. // It's assumed that all chunks in a source for one group for sorting purposes, // therefore GetChunkSortKey() returns just one key for the entire source. This // allows to know the key before reading/indexing the chunks. class ChunkSource { public: virtual ~ChunkSource() = default; // Returns a sort key (e.g. filename or a timestamp). virtual std::string GetChunkSortKey() const = 0; // Returns the number of chunks in this source. virtual size_t GetChunkCount() const = 0; // Returns the data for the chunk at the given index. Returns std::nullopt if // the chunk could not be read or if the data size is not a multiple of the // expected frame size. 
virtual std::optional> GetChunkData(size_t index) = 0; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/chunk_source/chunk_source_view.h ================================================ #pragma once #include #include #include #include #include #include #include "loader/chunk_source/chunk_source.h" namespace lczero { namespace training { // ChunkSourceView provides a view over another ChunkSource. // It exposes a remapped subset/order of chunks defined by indices into the // underlying source. It does not own or copy the data; it forwards calls to // the wrapped source. class ChunkSourceView : public ChunkSource { public: // Constructs a view over an existing chunk source. The indices vector maps // local indices in the view to indices of the underlying source. ChunkSourceView(std::shared_ptr source, std::vector indices) : source_(std::move(source)), indices_(std::move(indices)) {} ~ChunkSourceView() override = default; private: std::string GetChunkSortKey() const override { return source_->GetChunkSortKey(); } size_t GetChunkCount() const override { return indices_.size(); } std::optional> GetChunkData(size_t index) override { if (index >= indices_.size()) return std::nullopt; const size_t src_index = static_cast(indices_[index]); return source_->GetChunkData(src_index); } std::shared_ptr source_; std::vector indices_; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/chunk_source/debug_chunk_source.cc ================================================ #include "loader/chunk_source/debug_chunk_source.h" #include #include #include #include #include #include #include "absl/hash/hash.h" #include "absl/strings/str_format.h" namespace lczero { namespace training { DebugChunkSource::DebugChunkSource(uint64_t id, double mean_chunk_count) : id_(id), mean_chunk_count_(mean_chunk_count) {} std::string DebugChunkSource::GetChunkSortKey() const { 
return absl::StrFormat("%08" PRIu64, id_); } size_t DebugChunkSource::GetChunkCount() const { if (!cached_chunk_count_.has_value()) { std::mt19937_64 rng(id_); const double stddev = std::max(1.0, mean_chunk_count_ / 4.0); std::normal_distribution distribution(mean_chunk_count_, stddev); const double sampled = distribution(rng); const auto rounded = static_cast(std::llround(std::max(sampled, 1.0))); cached_chunk_count_ = static_cast(rounded); } return *cached_chunk_count_; } std::optional> DebugChunkSource::GetChunkData( size_t index) { const auto seed_pair = std::make_pair(id_, index); const uint64_t seed = static_cast( absl::Hash>{}(seed_pair)); std::mt19937_64 rng(seed); std::uniform_int_distribution frame_count_distribution(1, 200); const int frame_count = frame_count_distribution(rng); std::vector result(frame_count); for (int frame_index = 0; frame_index < frame_count; ++frame_index) { result[frame_index] = FrameType{}; result[frame_index].planes[0] = static_cast(id_); result[frame_index].planes[1] = static_cast(index); result[frame_index].planes[2] = static_cast(frame_index); } return result; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/chunk_source/debug_chunk_source.h ================================================ #pragma once #include #include #include #include #include #include "loader/chunk_source/chunk_source.h" namespace lczero { namespace training { // DebugChunkSource synthesizes deterministic pseudo-random chunks for loader // debugging. Each instance is identified by an integer id. The class produces // a chunk count sampled from a normal distribution with the provided mean and // mean / 4 standard deviation. The id serves as the seed, which keeps the // number of chunks stable across runs. Individual chunks contain a // pseudo-random number of FrameType frames (between one and 200) that are // generated on demand. 
The generation seed depends on both the source id and // chunk index. This lets shuffling logic exercise variable chunk sizes while // keeping the content reproducible. Each generated frame is zero-initialized, // but the first three entries of the planes array encode, respectively, the // source id, the chunk index, and the frame index within the chunk. This makes // it easy to reason about ordering and grouping when inspecting chunk payloads. class DebugChunkSource : public ChunkSource { public: DebugChunkSource(uint64_t id, double mean_chunk_count); private: std::string GetChunkSortKey() const override; size_t GetChunkCount() const override; std::optional> GetChunkData(size_t index) override; uint64_t id_; double mean_chunk_count_; mutable std::optional cached_chunk_count_; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/chunk_source/rawfile_chunk_source.cc ================================================ #include "loader/chunk_source/rawfile_chunk_source.h" #include #include #include #include "trainingdata/trainingdata_v6.h" #include "utils/files.h" #include "utils/gz.h" namespace lczero { namespace training { RawFileChunkSource::RawFileChunkSource( const std::filesystem::path& filename, ChunkSourceLoaderConfig::FrameFormat frame_format) : filename_(filename), frame_format_(frame_format) {} RawFileChunkSource::~RawFileChunkSource() = default; std::string RawFileChunkSource::GetChunkSortKey() const { return std::filesystem::path(filename_).filename().string(); } size_t RawFileChunkSource::GetChunkCount() const { return 1; } std::optional> RawFileChunkSource::GetChunkData( size_t index) { if (index != 0) return std::nullopt; std::string data = ReadFileToString(filename_); if (data.empty()) return std::nullopt; const size_t input_size = frame_format_ == ChunkSourceLoaderConfig::V7TrainingData ? 
sizeof(V7TrainingData) : sizeof(V6TrainingData); if (data.size() % input_size != 0) { LOG(WARNING) << "File " << filename_ << " size " << data.size() << " is not a multiple of input frame size " << input_size; return std::nullopt; } const size_t num_frames = data.size() / input_size; std::vector result(num_frames); if (frame_format_ == ChunkSourceLoaderConfig::V7TrainingData) { std::memcpy(result.data(), data.data(), data.size()); } else { const auto* v6_data = reinterpret_cast(data.data()); for (size_t i = 0; i < num_frames; ++i) { std::memcpy(&result[i], &v6_data[i], sizeof(V6TrainingData)); } } return result; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/chunk_source/rawfile_chunk_source.h ================================================ #pragma once #include #include #include "loader/chunk_source/chunk_source.h" #include "proto/data_loader_config.pb.h" namespace lczero { namespace training { // A chunk source that reads a single (potentially gzipped) file as a single // chunk. 
class RawFileChunkSource : public ChunkSource { public: RawFileChunkSource(const std::filesystem::path& filename, ChunkSourceLoaderConfig::FrameFormat frame_format); ~RawFileChunkSource(); private: std::string GetChunkSortKey() const override; size_t GetChunkCount() const override; std::optional> GetChunkData(size_t index) override; std::string filename_; ChunkSourceLoaderConfig::FrameFormat frame_format_; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/chunk_source/tar_chunk_source.cc ================================================ #include "loader/chunk_source/tar_chunk_source.h" #include #include #include #include #include #include #include #include #include #include #include #include "trainingdata/trainingdata_v6.h" #include "utils/gz.h" namespace lczero { namespace training { namespace { struct TarHeader { std::array name; std::array mode; std::array uid; std::array gid; std::array size; std::array mtime; std::array chksum; uint8_t typeflag; std::array linkname; std::array magic; std::array version; std::array uname; std::array gname; std::array devmajor; std::array devminor; std::array prefix; std::array padding; }; static_assert(sizeof(TarHeader) == 512, "TarHeader must be exactly 512 bytes"); uint64_t ParseOctal(const std::array& octal) { uint64_t value = 0; for (uint8_t digit : octal) { if (!digit) break; value = (value << 3) + (digit - '0'); } return value; } bool ReadExact(int fd, off_t offset, void* buffer, size_t size) { char* out = static_cast(buffer); size_t read_total = 0; while (read_total < size) { const ssize_t read_now = pread(fd, out + read_total, size - read_total, offset + read_total); if (read_now <= 0) return false; read_total += static_cast(read_now); } return true; } std::optional ReadGzipPrefix(int fd, off_t offset, size_t size, size_t max_bytes) { if (max_bytes == 0) return std::string(); z_stream strm = {}; if (inflateInit2(&strm, 16 + MAX_WBITS) != Z_OK) { return 
std::nullopt; } constexpr size_t kChunkSize = 16384; std::array input_buffer; std::array output_buffer; std::string output; output.reserve(std::min(max_bytes, kChunkSize)); size_t remaining = size; off_t current_offset = offset; bool finished = false; while (remaining > 0 && !finished && output.size() < max_bytes) { const size_t to_read = std::min(remaining, kChunkSize); if (!ReadExact(fd, current_offset, input_buffer.data(), to_read)) { inflateEnd(&strm); return std::nullopt; } remaining -= to_read; current_offset += static_cast(to_read); strm.next_in = reinterpret_cast(input_buffer.data()); strm.avail_in = static_cast(to_read); while (strm.avail_in > 0 && output.size() < max_bytes) { strm.next_out = reinterpret_cast(output_buffer.data()); strm.avail_out = kChunkSize; const int ret = inflate(&strm, Z_NO_FLUSH); if (ret == Z_STREAM_ERROR || ret == Z_NEED_DICT || ret == Z_DATA_ERROR || ret == Z_MEM_ERROR) { inflateEnd(&strm); return std::nullopt; } const size_t produced = kChunkSize - strm.avail_out; const size_t to_copy = std::min(produced, max_bytes - output.size()); output.append(output_buffer.data(), to_copy); if (ret == Z_STREAM_END) { finished = true; break; } } } inflateEnd(&strm); return output; } } // namespace TarChunkSource::TarChunkSource( const std::filesystem::path& filename, ChunkSourceLoaderConfig::FrameFormat frame_format) : path_(filename), filename_(filename.filename().string()), frame_format_(frame_format) { fd_ = open(path_.c_str(), O_RDONLY | O_CLOEXEC); if (fd_ < 0) { throw std::runtime_error( absl::StrCat("Failed to open tar file: ", path_.string(), ": ", std::strerror(errno))); } // Perform indexing during construction. 
Index(); } TarChunkSource::~TarChunkSource() { Close(); } void TarChunkSource::Close() { if (fd_ >= 0 && close(fd_) != 0) { PLOG(WARNING) << "Failed to close tar file descriptor for " << path_; } fd_ = -1; } std::string TarChunkSource::GetChunkSortKey() const { return filename_; } void TarChunkSource::Index() { assert(files_.empty()); off_t offset = 0; while (true) { TarHeader header; if (!ReadExact(fd_, offset, &header, sizeof(header))) { LOG(WARNING) << "Truncated tar file: " << filename_; break; } offset += sizeof(header); if (header.name[0] == '\0') break; // End of file. switch (header.typeflag) { case '5': // Directory continue; case '0': // Regular file break; default: LOG(WARNING) << "Unsupported tar header type: " << header.typeflag; continue; } std::string_view fname(const_cast(header.name.data())); const std::filesystem::path filepath = std::filesystem::path(fname); const off_t file_offset = offset; const size_t size = ParseOctal(header.size); const size_t padded_size = ((size + 511) / 512) * 512; offset = file_offset + static_cast(padded_size); if (filepath.filename() == "LICENSE") continue; files_.push_back({file_offset, size, filepath.extension() == ".gz"}); } LOG(INFO) << "Read " << files_.size() << " entries from " << filename_; } size_t TarChunkSource::GetChunkCount() const { return files_.size(); } std::optional> TarChunkSource::GetChunkData( size_t index) { if (index >= files_.size()) { throw std::out_of_range("File index out of range"); } const auto& file_entry = files_[index]; std::string content(file_entry.size, '\0'); if (!ReadExact(fd_, file_entry.offset, content.data(), file_entry.size)) { return std::nullopt; } if (file_entry.is_gzip) { try { content = GunzipBuffer(content); } catch (const GunzipError& e) { return std::nullopt; } } if (content.empty()) return std::nullopt; const size_t input_size = frame_format_ == ChunkSourceLoaderConfig::V7TrainingData ? 
sizeof(V7TrainingData) : sizeof(V6TrainingData); if (content.size() % input_size != 0) { LOG(WARNING) << "Chunk " << index << " from " << filename_ << " size " << content.size() << " is not a multiple of input frame size " << input_size; return std::nullopt; } const size_t num_frames = content.size() / input_size; std::vector result(num_frames); if (frame_format_ == ChunkSourceLoaderConfig::V7TrainingData) { std::memcpy(result.data(), content.data(), content.size()); } else { const auto* v6_data = reinterpret_cast(content.data()); for (size_t i = 0; i < num_frames; ++i) { std::memcpy(&result[i], &v6_data[i], sizeof(V6TrainingData)); } } return result; } std::optional TarChunkSource::GetChunkPrefix(size_t index, size_t max_bytes) { if (index >= files_.size()) { throw std::out_of_range("File index out of range"); } const auto& file_entry = files_[index]; if (file_entry.is_gzip) { return ReadGzipPrefix(fd_, file_entry.offset, file_entry.size, max_bytes); } const size_t to_read = std::min(file_entry.size, max_bytes); std::string content(to_read, '\0'); if (!ReadExact(fd_, file_entry.offset, content.data(), to_read)) { return std::nullopt; } return content; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/chunk_source/tar_chunk_source.h ================================================ #pragma once #include #include #include #include #include "loader/chunk_source/chunk_source.h" #include "proto/data_loader_config.pb.h" namespace lczero { namespace training { // A chunk source that reads a tar archive and provides access to its files as // chunks. Each file in the tar is treated as a separate chunk. 
// NOTE(review): throughout this dump the text extraction stripped everything
// inside angle brackets (#include targets, template arguments). The tokens are
// reproduced as found; confirm the originals against the repository.
class TarChunkSource : public ChunkSource {
 public:
  TarChunkSource(const std::filesystem::path& filename,
                 ChunkSourceLoaderConfig::FrameFormat frame_format);
  ~TarChunkSource() override;
  std::string GetChunkSortKey() const override;
  size_t GetChunkCount() const override;
  // NOTE(review): return type template arguments lost in extraction.
  std::optional> GetChunkData(size_t index) override;
  // Reads at most max_bytes from the start of the chunk (not an override).
  std::optional GetChunkPrefix(size_t index, size_t max_bytes);

 private:
  // Performs one-time indexing during construction. Not part of the interface.
  void Index();
  // Location of one member file inside the tar archive.
  struct FileEntry { off_t offset; size_t size; bool is_gzip; };
  void Close();
  int fd_ = -1;  // File descriptor for the tar archive; -1 when closed.
  std::vector files_;
  std::filesystem::path path_;
  std::string filename_;
  ChunkSourceLoaderConfig::FrameFormat frame_format_;
};
} // namespace training
} // namespace lczero
================================================ FILE: csrc/loader/data_loader.cc ================================================
#include "loader/data_loader.h"
// NOTE(review): the targets of these includes were lost in extraction.
#include #include #include #include #include #include #include #include
#include "loader/data_loader_metrics.h"
namespace lczero {
namespace training {

// Deserializes a DataLoaderConfig from its protobuf wire format.
DataLoaderConfig DataLoader::ParseConfig(const std::string& serialized_config) {
  DataLoaderConfig config;
  config.ParseFromString(serialized_config);
  return config;
}

// Builds the stage graph from the serialized config and resolves the named
// outputs. The aggregator receives a reset callback and a merge callback.
DataLoader::DataLoader(const std::string& serialized_data_loader_config)
    : metrics_aggregator_(
          [](DataLoaderMetricsProto& m) { m.Clear(); },
          [](DataLoaderMetricsProto& dest, const DataLoaderMetricsProto& src) {
            UpdateFrom(dest, src);
          }) {
  DataLoaderConfig config = ParseConfig(serialized_data_loader_config);
  AddStages(config);
  BuildOutputMapping(config);
  LOG(INFO) << "DataLoader initialized with " << stage_registry_.size()
            << " stage(s) and " << outputs_.size() << " output(s).";
}

DataLoader::~DataLoader() { Stop(); }

void DataLoader::AddStages(const std::string& serialized_data_loader_config) {
  AddStages(ParseConfig(serialized_data_loader_config));
}

// Two passes: create every stage first, then wire inputs, so a stage may
// reference a stage declared later in the config.
void DataLoader::AddStages(const DataLoaderConfig& config) {
  for (const auto& stage_config : config.stage()) AddStage(stage_config);
  for (const auto& stage_config : config.stage()) SetStageInputs(stage_config);
}

// Creates one stage and registers it under its configured name.
// Throws if the loader already started or the config lacks a name.
void DataLoader::AddStage(const StageConfig& stage_config) {
  if (started_) {
    throw std::runtime_error("Cannot add stages after DataLoader has started.");
  }
  // NOTE(review): the stage is constructed before the name check; a nameless
  // config still pays the construction cost before throwing.
  auto stage = CreateStage(stage_config);
  if (!stage_config.has_name()) {
    throw std::runtime_error("Stage configuration is missing name.");
  }
  LOG(INFO) << "Adding stage '" << stage_config.name() << "'.";
  stage_registry_.AddStage(stage_config.name(), std::move(stage));
}

// Resolves this stage's configured input stage names to output-queue pointers
// and hands them to the stage.
void DataLoader::SetStageInputs(const StageConfig& stage_config) {
  // Resolve input names to queue pointers.
  std::vector input_queues;
  input_queues.reserve(stage_config.input_size());
  for (const auto& input_name : stage_config.input()) {
    QueueBase* queue = stage_registry_.GetStageOutput(input_name);
    if (!queue) {
      throw std::runtime_error(absl::StrCat("Input stage '", input_name,
                                            "' not found for stage '",
                                            stage_config.name(), "'."));
    }
    input_queues.push_back(queue);
  }
  // Wire up inputs.
  auto it = absl::c_find_if(stage_registry_.stages(), [&](const auto& p) {
    return p.first == stage_config.name();
  });
  if (it == stage_registry_.stages().end()) {
    throw std::runtime_error(absl::StrCat("Stage '", stage_config.name(),
                                          "' not found in registry."));
  }
  it->second->SetInputs(absl::MakeSpan(input_queues));
}

// Starts every stage in registration order, then the metrics thread.
void DataLoader::Start() {
  if (started_) {
    throw std::runtime_error("DataLoader has already been started.");
  }
  for (auto& [name, stage] : stage_registry_.stages()) {
    LOG(INFO) << "Starting stage '" << name << "'.";
    stage->Start();
  }
  metrics_thread_ = std::jthread(
      [this](std::stop_token stop_token) { MetricsThread(stop_token); });
  started_ = true;
  stopped_ = false;
  LOG(INFO) << "DataLoader started.";
}

// Blocking fetch from the output queue registered under `alias`.
// Throws with the list of valid aliases when the alias is unknown.
TensorTuple DataLoader::GetNext(std::string_view alias) {
  Queue* q = GetOutputQueue(alias);
  if (!q) {
    std::string alias_list = absl::StrJoin(
        outputs_, ", ",
        [](std::string* out, const auto& p) { absl::StrAppend(out, p.first); });
    throw std::runtime_error(absl::StrCat("Unknown DataLoader output: '", alias,
                                          "'. Available outputs: [", alias_list,
                                          "]."));
  }
  return q->Get();
}

// Non-blocking variant of GetNext(); returns nullopt when no batch is ready.
std::optional DataLoader::MaybeGetNext(std::string_view alias) {
  Queue* q = GetOutputQueue(alias);
  if (!q) {
    std::string alias_list = absl::StrJoin(
        outputs_, ", ",
        [](std::string* out, const auto& p) { absl::StrAppend(out, p.first); });
    throw std::runtime_error(absl::StrCat("Unknown DataLoader output: '", alias,
                                          "'. Available outputs: [", alias_list,
                                          "]."));
  }
  return q->MaybeGet();
}

// Stops the metrics thread first, then the stages. Idempotent.
void DataLoader::Stop() {
  if (stopped_) return;
  LOG(INFO) << "Stopping DataLoader.";
  if (metrics_thread_.joinable()) {
    metrics_thread_.request_stop();
    metrics_thread_.join();
  }
  for (auto& [name, stage] : stage_registry_.stages()) {
    LOG(INFO) << "Stopping stage '" << name << "'.";
    stage->Stop();
  }
  stopped_ = true;
  started_ = false;
  LOG(INFO) << "DataLoader stopped.";
}

// Returns the serialized metrics of one aggregation bucket plus the bucket's
// covered duration in seconds. `time_period` selects the bucket granularity.
std::pair DataLoader::GetBucketMetrics(
    int time_period, bool include_pending) const {
  auto [metrics, duration] = metrics_aggregator_.GetBucketMetrics(
      static_cast(time_period),
      include_pending ? std::make_optional(std::chrono::steady_clock::now())
                      : std::nullopt);
  float duration_seconds = std::chrono::duration(duration).count();
  return {metrics.OutputAsString(), duration_seconds};
}

// Aggregates metrics covering approximately the trailing `duration_seconds`;
// infinity means "everything recorded".
std::pair DataLoader::GetAggregateEndingNow(
    float duration_seconds, bool include_pending) const {
  std::chrono::nanoseconds duration_ns =
      std::isinf(duration_seconds)
          ? std::chrono::nanoseconds::max()
          : std::chrono::nanoseconds(
                static_cast(duration_seconds * 1e9));
  auto [metrics, duration] = metrics_aggregator_.GetAggregateEndingNow(
      duration_ns,
      include_pending ? std::make_optional(std::chrono::steady_clock::now())
                      : std::nullopt);
  float result_duration_seconds = std::chrono::duration(duration).count();
  return {metrics.OutputAsString(), result_duration_seconds};
}

// Polls every stage roughly 10x per second, flushing per-stage metrics into
// the exponential aggregator until a stop is requested.
void DataLoader::MetricsThread(std::stop_token stop_token) {
  while (!stop_token.stop_requested()) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    DataLoaderMetricsProto metrics;
    for (auto& [name, stage] : stage_registry_.stages()) {
      StageMetricProto stage_metric = stage->FlushMetrics();
      stage_metric.set_name(name);
      *metrics.add_stage_metrics() = std::move(stage_metric);
    }
    metrics_aggregator_.RecordMetrics(std::move(metrics));
    metrics_aggregator_.Advance(std::chrono::steady_clock::now());
  }
  LOG(INFO) << "Metrics thread stopping.";
}

// Broadcasts a control request to every stage; collects (name, response)
// pairs for the stages that chose to answer.
std::vector> DataLoader::SendControlMessage(
    const StageControlRequest& request) {
  std::vector> responses;
  responses.reserve(stage_registry_.size());
  for (auto& [name, stage] : stage_registry_.stages()) {
    std::optional response = stage->Control(request);
    if (response.has_value()) {
      responses.emplace_back(name, std::move(*response));
    }
  }
  return responses;
}

// Parses output specs of the form "alias:stage_name" (or bare "stage_name",
// which gets the empty alias) into the outputs_ alias->queue table.
void DataLoader::BuildOutputMapping(const DataLoaderConfig& config) {
  outputs_.clear();
  outputs_.reserve(config.output_size());
  for (const auto& out_spec : config.output()) {
    auto it = absl::c_find(out_spec, ':');
    std::string alias = it == out_spec.end()
                            ? std::string("")
                            : std::string(out_spec.begin(), it);
    std::string stage_name = it == out_spec.end()
                                 ? std::string(out_spec)
                                 : std::string(it + 1, out_spec.end());
    // Ensure alias is unique.
    if (absl::c_find_if(outputs_, [&](const auto& p) {
          return p.first == alias;
        }) != outputs_.end()) {
      throw std::runtime_error(
          absl::StrCat("Duplicate output alias specified: ", alias));
    }
    Queue* queue = stage_registry_.GetTypedStageOutput(stage_name);
    if (queue == nullptr) {
      throw std::runtime_error(
          absl::StrCat("Output stage not found or wrong type: ", stage_name));
    }
    outputs_.emplace_back(std::move(alias), queue);
  }
}

// Linear lookup of an output queue by alias; nullptr when unknown.
Queue* DataLoader::GetOutputQueue(std::string_view alias) const {
  auto it = absl::c_find_if(outputs_, [&](const auto& p) {
    return p.first == alias;
  });
  if (it == outputs_.end()) return nullptr;
  return it->second;
}

} // namespace training
} // namespace lczero
================================================ FILE: csrc/loader/data_loader.h ================================================
#pragma once
// NOTE(review): include targets lost in extraction.
#include #include #include #include #include #include
#include "loader/stages/stage_factory.h"
#include "proto/data_loader_config.pb.h"
#include "proto/stage_control.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/exponential_aggregator.h"
#include "utils/queue.h"
#include "utils/tensor.h"
namespace lczero {
namespace training {

using TensorDict = std::vector>>;

// Owns the stage graph, the named output queues and the metrics aggregation
// thread. Construct from a serialized DataLoaderConfig, then Start()/GetNext().
class DataLoader {
 public:
  using MetricsAggregator = ExponentialAggregator;
  explicit DataLoader(const std::string& serialized_data_loader_config);
  ~DataLoader();
  void Start();
  TensorTuple GetNext(std::string_view alias);
  std::optional MaybeGetNext(std::string_view alias);
  void Stop();
  std::pair GetBucketMetrics(int time_period,
                             bool include_pending) const;
  std::pair GetAggregateEndingNow(float duration_seconds,
                                  bool include_pending) const;
  void AddStages(const DataLoaderConfig& config);
  void AddStages(const std::string& serialized_data_loader_config);
  std::vector> SendControlMessage(
      const StageControlRequest& request);

 private:
  void AddStage(const StageConfig& stage_config);
  void SetStageInputs(const StageConfig& stage_config);
  static DataLoaderConfig ParseConfig(
      const
std::string& serialized_data_loader_config);
  void MetricsThread(std::stop_token stop_token);
  void BuildOutputMapping(const DataLoaderConfig& config);
  Queue* GetOutputQueue(std::string_view alias) const;

  StageRegistry stage_registry_;
  // Alias -> output queue, in config order. NOTE(review): the element type's
  // template arguments were stripped by extraction.
  std::vector*>> outputs_;
  MetricsAggregator metrics_aggregator_;
  std::jthread metrics_thread_;
  bool started_ = false;
  bool stopped_ = false;
};
} // namespace training
} // namespace lczero
================================================ FILE: csrc/loader/data_loader_metrics.cc ================================================
// ABOUTME: Implementation of UpdateFrom functions for data loader metric
// protobuf messages. ABOUTME: Handles aggregation of generic stage metrics.
#include "loader/data_loader_metrics.h"
#include
#include "absl/strings/string_view.h"
#include "utils/metrics/statistics_metric.h"
namespace lczero {
namespace training {
namespace {

// Linear scan for a repeated-proto entry with the given name; nullptr if none.
template ProtoT* FindByName(std::vector* entries, absl::string_view name) {
  for (auto& entry : *entries) {
    if (entry.name() == name) return &entry;
  }
  return nullptr;
}

} // namespace

// Counters are summed; fullness statistics are merged; name and capacity take
// the most recent value from src.
void UpdateFrom(QueueMetricProto& dest, const QueueMetricProto& src) {
  if (src.has_name()) dest.set_name(src.name());
  dest.set_put_count(dest.put_count() + src.put_count());
  dest.set_get_count(dest.get_count() + src.get_count());
  dest.set_drop_count(dest.drop_count() + src.drop_count());
  UpdateFrom(*dest.mutable_queue_fullness(), src.queue_fullness());
  if (src.has_queue_capacity()) dest.set_queue_capacity(src.queue_capacity());
}

// Counts accumulate.
void UpdateFrom(CountMetricProto& dest, const CountMetricProto& src) {
  if (src.has_name()) dest.set_name(src.name());
  if (src.has_count()) dest.set_count(dest.count() + src.count());
}

// Gauges keep the latest value instead of summing.
void UpdateFrom(GaugeMetricProto& dest, const GaugeMetricProto& src) {
  if (src.has_name()) dest.set_name(src.name());
  if (src.has_value()) dest.set_value(src.value());
  if (src.has_capacity()) dest.set_capacity(src.capacity());
}

// Merges every repeated sub-metric by name; a src entry without a name (or
// with a name not yet present) is appended as a new entry.
void UpdateFrom(StageMetricProto& dest, const StageMetricProto& src) {
  if (src.has_name()) dest.set_name(src.name());
  for (const auto& load_metrics : src.load_metrics()) {
    LoadMetricProto* dest_load =
        load_metrics.has_name()
            ? FindByName(dest.mutable_load_metrics(), load_metrics.name())
            : nullptr;
    if (dest_load == nullptr) {
      dest_load = dest.add_load_metrics();
    }
    UpdateFrom(*dest_load, load_metrics);
  }
  for (const auto& queue_metrics : src.queue_metrics()) {
    QueueMetricProto* dest_queue =
        queue_metrics.has_name()
            ? FindByName(dest.mutable_queue_metrics(), queue_metrics.name())
            : nullptr;
    if (dest_queue == nullptr) {
      dest_queue = dest.add_queue_metrics();
    }
    UpdateFrom(*dest_queue, queue_metrics);
  }
  for (const auto& count_metrics : src.count_metrics()) {
    CountMetricProto* dest_count =
        count_metrics.has_name()
            ? FindByName(dest.mutable_count_metrics(), count_metrics.name())
            : nullptr;
    if (dest_count == nullptr) {
      dest_count = dest.add_count_metrics();
    }
    UpdateFrom(*dest_count, count_metrics);
  }
  for (const auto& gauge_metrics : src.gauge_metrics()) {
    GaugeMetricProto* dest_gauge =
        gauge_metrics.has_name()
            ? FindByName(dest.mutable_gauge_metrics(), gauge_metrics.name())
            : nullptr;
    if (dest_gauge == nullptr) {
      dest_gauge = dest.add_gauge_metrics();
    }
    UpdateFrom(*dest_gauge, gauge_metrics);
  }
  for (const auto& statistics_metrics : src.statistics_metrics()) {
    StatisticsProtoDouble* dest_stats =
        statistics_metrics.has_name()
            ? FindByName(dest.mutable_statistics_metrics(),
                         statistics_metrics.name())
            : nullptr;
    if (dest_stats == nullptr) {
      dest_stats = dest.add_statistics_metrics();
    }
    UpdateFrom(*dest_stats, statistics_metrics);
  }
  if (src.has_last_chunk_key()) dest.set_last_chunk_key(src.last_chunk_key());
  if (src.has_anchor()) dest.set_anchor(src.anchor());
}

// Top-level merge: matches stages by name and merges each pair.
void UpdateFrom(DataLoaderMetricsProto& dest,
                const DataLoaderMetricsProto& src) {
  for (const auto& stage_metrics : src.stage_metrics()) {
    StageMetricProto* dest_stage =
        stage_metrics.has_name()
            ? FindByName(dest.mutable_stage_metrics(), stage_metrics.name())
            : nullptr;
    if (dest_stage == nullptr) {
      dest_stage = dest.add_stage_metrics();
    }
    UpdateFrom(*dest_stage, stage_metrics);
  }
}

} // namespace training
} // namespace lczero
================================================ FILE: csrc/loader/data_loader_metrics.h ================================================
// ABOUTME: Header for UpdateFrom functions for data loader metric protobuf
// messages. ABOUTME: Declares functions for aggregating FilePathProvider and
// DataLoader metrics.
#pragma once
#include "absl/strings/string_view.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/load_metric.h"
#include "utils/metrics/statistics_metric.h"
#include "utils/queue.h"
namespace lczero {
namespace training {

void UpdateFrom(QueueMetricProto& dest, const QueueMetricProto& src);
void UpdateFrom(CountMetricProto& dest, const CountMetricProto& src);
void UpdateFrom(GaugeMetricProto& dest, const GaugeMetricProto& src);
void UpdateFrom(StageMetricProto& dest, const StageMetricProto& src);
void UpdateFrom(DataLoaderMetricsProto& dest,
                const DataLoaderMetricsProto& src);

// Snapshots a queue's counters into a QueueMetricProto. The `true` arguments
// reset the queue's running totals as they are read (flush semantics).
template QueueMetricProto MetricsFromQueue(absl::string_view name, Queue& queue) {
  QueueMetricProto result;
  result.set_name(std::string(name));
  result.set_put_count(queue.GetTotalPutCount(true));
  result.set_get_count(queue.GetTotalGetCount(true));
  result.set_drop_count(queue.GetTotalDropCount(true));
  AddSample(*result.mutable_queue_fullness(), queue.Size());
  result.set_queue_capacity(queue.Capacity());
  return result;
}

} // namespace training
} // namespace lczero
================================================ FILE: csrc/loader/data_loader_test.cc ================================================
#include "loader/data_loader.h"
#include #include
namespace lczero {
namespace training {

// A loader with stages but no declared outputs must construct cleanly.
TEST(DataLoaderTest, AllowsNoOutputsConfigured) {
  DataLoaderConfig config;
  auto* file_stage = config.add_stage();
file_stage->set_name("file_path_provider");
  file_stage->mutable_file_path_provider()->set_directory(".");
  EXPECT_NO_THROW(DataLoader(config.OutputAsString()));
}

// Registering two stages under the same name is a configuration error.
TEST(DataLoaderTest, ThrowsOnDuplicateStageName) {
  DataLoaderConfig config;
  auto* first_stage = config.add_stage();
  first_stage->set_name("duplicate");
  first_stage->mutable_file_path_provider()->set_directory(".");
  auto* second_stage = config.add_stage();
  second_stage->set_name("duplicate");
  second_stage->mutable_file_path_provider()->set_directory(".");
  EXPECT_THROW(DataLoader(config.OutputAsString()), std::runtime_error);
}

} // namespace training
} // namespace lczero
================================================ FILE: csrc/loader/frame_type.h ================================================
/* This file is part of Leela Chess Zero. Copyright (C) 2025 The LCZero Authors Leela Chess is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. Leela Chess is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with Leela Chess. If not, see . Additional permission under GNU GPL version 3 section 7 If you modify this Program, or any covered work, by linking or combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA Toolkit and the NVIDIA CUDA Deep Neural Network library (or a modified version of those libraries), containing parts covered by the terms of the respective license agreement, the licensors of this Program grant you additional permission to convey the resulting work. */
#pragma once
#include "trainingdata/trainingdata_v7.h"
namespace lczero {
namespace training {
// Training frames flow through the loader in the V7 on-disk layout.
using FrameType = V7TrainingData;
} // namespace training
} // namespace lczero
================================================ FILE: csrc/loader/loader_main.cpp ================================================
// NOTE(review): include targets lost in extraction.
#include #include #include #include #include #include #include #include #include #include #include
#include "data_loader.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/printer.h"
ABSL_FLAG(std::string, directory, "/home/crem/tmp/2025-07/lczero-training/",
          "Directory to watch for training data files");
ABSL_FLAG(size_t, chunk_pool_size, 1000000, "Size of the chunk shuffle buffer");
ABSL_FLAG(size_t, reservoir_size_per_thread, 1000000,
          "Size of the reservoir for frame sampling per thread");
namespace lczero {
namespace training {

// Standalone driver: builds the canonical six-stage pipeline
// (file_path_provider -> chunk_source_loader -> shuffling_chunk_pool ->
// chunk_unpacker -> shuffling_frame_sampler -> tensor_generator) and
// constructs a DataLoader from it.
void Run() {
  DataLoaderConfig config;
  // Configure file path provider stage.
  auto* file_stage = config.add_stage();
  file_stage->set_name("file_path_provider");
  auto* file_path_provider = file_stage->mutable_file_path_provider();
  file_path_provider->set_directory(absl::GetFlag(FLAGS_directory));
  // Configure chunk source loader stage.
  auto* chunk_loader_stage = config.add_stage();
  chunk_loader_stage->set_name("chunk_source_loader");
  chunk_loader_stage->add_input(file_stage->name());
  chunk_loader_stage->mutable_chunk_source_loader();
  // Configure shuffling chunk pool stage.
  auto* chunk_pool_stage = config.add_stage();
  chunk_pool_stage->set_name("shuffling_chunk_pool");
  chunk_pool_stage->add_input(chunk_loader_stage->name());
  auto* shuffling_chunk_pool = chunk_pool_stage->mutable_shuffling_chunk_pool();
  shuffling_chunk_pool->set_chunk_pool_size(
      absl::GetFlag(FLAGS_chunk_pool_size));
  // Configure chunk unpacker stage.
  auto* unpacker_stage = config.add_stage();
  unpacker_stage->set_name("chunk_unpacker");
  unpacker_stage->add_input(chunk_pool_stage->name());
  unpacker_stage->mutable_chunk_unpacker();
  // Configure shuffling frame sampler stage.
  auto* sampler_stage = config.add_stage();
  sampler_stage->set_name("shuffling_frame_sampler");
  sampler_stage->add_input(unpacker_stage->name());
  auto* shuffling_frame_sampler =
      sampler_stage->mutable_shuffling_frame_sampler();
  shuffling_frame_sampler->set_reservoir_size_per_thread(
      absl::GetFlag(FLAGS_reservoir_size_per_thread));
  // Configure tensor generator stage.
  auto* tensor_stage = config.add_stage();
  tensor_stage->set_name("tensor_generator");
  tensor_stage->add_input(sampler_stage->name());
  tensor_stage->mutable_tensor_generator();
  // Serialize config and create loader
  config.add_output("tensor_generator");
  std::string config_string = config.OutputAsString();
  DataLoader loader(config_string);
  // NOTE(review): this early return makes everything below unreachable dead
  // code (apparently a debug leftover). If re-enabled, note that
  // loader.Start() is never called, GetNext("train") does not match the
  // configured output (added as "tensor_generator" with no "train" alias),
  // and logging_thread is never joined.
  return;
  std::atomic batch_count = 0;
  auto start_time = absl::Now();
  // Start logging thread
  std::atomic should_stop{false};
  std::thread logging_thread([&]() {
    while (!should_stop.load()) {
      std::this_thread::sleep_for(std::chrono::seconds(1));
      auto current_time = absl::Now();
      auto total_elapsed = current_time - start_time;
      double rate = batch_count / absl::ToDoubleSeconds(total_elapsed);
      auto [stats_string, duration] = loader.GetBucketMetrics(0, false);  // k1Second = 0
      DataLoaderMetricsProto metrics;
      metrics.ParseFromString(stats_string);
      std::string metrics_json = metrics.OutputAsJson();
      LOG(INFO) << absl::StrCat("Processed ", batch_count.load(),
                                " batches in ",
                                absl::ToDoubleSeconds(total_elapsed),
                                "s. Rate: ", absl::StrFormat("%.2f", rate),
                                " batches/sec. ", "Metrics: ", metrics_json);
    }
  });
  while (true) {
    TensorTuple batch = loader.GetNext("train");
    ++batch_count;
  }
}

} // namespace training
} // namespace lczero

// Parses flags, initializes absl logging, and runs the pipeline driver.
int main(int argc, char* argv[]) {
  absl::ParseCommandLine(argc, argv);
  absl::InitializeLog();
  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
  lczero::training::Run();
  return 0;
}
================================================ FILE: csrc/loader/pybind_module.cc ================================================
// ABOUTME: PyBind11 binding module exposing C++ DataLoader to Python.
// ABOUTME: Handles configuration conversion and tensor memory management for
// numpy arrays.
// NOTE(review): include targets lost in extraction.
#include #include #include #include #include #include #include #include #include
#include "loader/data_loader.h"
#include "loader/stages/chunk_source_loader.h"
#include "loader/stages/chunk_unpacker.h"
#include "loader/stages/file_path_provider.h"
#include "loader/stages/shuffling_chunk_pool.h"
#include "loader/stages/shuffling_frame_sampler.h"
#include "loader/stages/tensor_generator.h"
#include "utils/tensor.h"
namespace py = pybind11;
namespace lczero {
namespace training {
namespace {

// Accepts either raw bytes or a protobuf message object (anything with
// SerializeToString) and returns the serialized wire-format string.
std::string SerializePyProto(const py::handle& obj, const char* expected_type) {
  if (py::isinstance(obj)) {
    return obj.cast();
  }
  if (py::hasattr(obj, "SerializeToString")) {
    py::object bytes_obj = obj.attr("SerializeToString")();
    return bytes_obj.cast().cast();
  }
  throw std::invalid_argument(std::string("Expected ") + expected_type +
                              " protobuf message or bytes.");
}

// Parses a Python proto/bytes object into the C++ proto type ProtoT.
template ProtoT ParsePyProto(const py::handle& obj, const char* expected_type) {
  ProtoT proto;
  proto.ParseFromString(SerializePyProto(obj, expected_type));
  return proto;
}

// Builds a Python protobuf message of the given type from serialized bytes.
py::object MakePythonProto(const char* module_name, const char* message_name,
                           const std::string& serialized) {
  py::object message_cls = py::module::import(module_name).attr(message_name);
  py::object message_obj = message_cls();
  message_obj.attr("ParseFromString")(py::bytes(serialized));
  return message_obj;
}

}
// namespace

// Helper function to convert TensorBase to numpy array using buffer protocol.
py::array tensor_to_numpy(std::unique_ptr tensor) {
  // Extract raw pointer and release ownership from unique_ptr.
  TensorBase* raw_tensor = tensor.release();
  // Create numpy array with take_ownership policy.
  // This transfers memory ownership to Python/numpy.
  return py::array(
      py::dtype(raw_tensor->py_format()), raw_tensor->shape(),
      raw_tensor->strides(), raw_tensor->data(),
      py::cast(raw_tensor, py::return_value_policy::take_ownership));
}

// Convert TensorTuple to tuple of numpy arrays.
py::tuple tensor_tuple_to_numpy_tuple(TensorTuple tensor_tuple) {
  py::tuple result(tensor_tuple.size());
  for (size_t i = 0; i < tensor_tuple.size(); ++i) {
    result[i] = tensor_to_numpy(std::move(tensor_tuple[i]));
  }
  return result;
}

// Module definition. Blocking/long-running DataLoader calls release the GIL
// so Python threads keep running while C++ works.
PYBIND11_MODULE(_lczero_training, m) {
  absl::InitializeLog();
  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
  m.doc() = "Leela Chess Zero training data loader";
  // Configuration is now handled via protobuf serialized strings
  // Expose the main DataLoader class.
  py::class_(m, "DataLoader")
      .def(py::init([](py::object config) {
             std::string config_string =
                 SerializePyProto(config, "DataLoaderConfig");
             py::gil_scoped_release release;
             return new DataLoader(config_string);
           }),
           py::arg("config"),
           "Create DataLoader from DataLoaderConfig proto or bytes.")
      .def(
          "add_stages",
          [](DataLoader& self, py::object config) {
            std::string config_string =
                SerializePyProto(config, "DataLoaderConfig");
            py::gil_scoped_release release;
            self.AddStages(config_string);
          },
          py::arg("config"),
          "Append stages from DataLoaderConfig proto or bytes.")
      .def(
          "send_control_message",
          [](DataLoader& self, py::object request) {
            StageControlRequest control_request =
                ParsePyProto(request, "StageControlRequest");
            auto responses = [&]() {
              py::gil_scoped_release release;
              return self.SendControlMessage(control_request);
            }();
            py::list result;
            for (const auto& [stage_name, response] : responses) {
              py::object response_obj = MakePythonProto(
                  "proto.stage_control_pb2", "StageControlResponse",
                  response.OutputAsString());
              result.append(py::make_tuple(stage_name, response_obj));
            }
            return result;
          },
          py::arg("request"),
          "Send StageControlRequest to stages and return (stage_name, "
          "StageControlResponse) tuples.")
      .def(
          "get_next",
          [](DataLoader& self, const std::string& alias) {
            return tensor_tuple_to_numpy_tuple([&] {
              py::gil_scoped_release release;
              return self.GetNext(alias);
            }());
          },
          py::arg("alias") = "",
          "Get next batch for the given output alias (default empty) as a "
          "tuple of numpy arrays")
      .def(
          "maybe_get_next",
          [](DataLoader& self, const std::string& alias) -> std::optional {
            auto result = [&] {
              py::gil_scoped_release release;
              return self.MaybeGetNext(alias);
            }();
            if (result.has_value()) {
              return tensor_tuple_to_numpy_tuple(std::move(*result));
            }
            return std::nullopt;
          },
          py::arg("alias") = "",
          "Non-blocking get next batch for the given output alias (default "
          "empty). Returns tuple of numpy arrays or None if no data available")
      .def(
          "get_bucket_metrics",
          [](const DataLoader& self, int time_period, bool include_pending) {
            auto [metrics, duration] = [&] {
              py::gil_scoped_release release;
              return self.GetBucketMetrics(time_period, include_pending);
            }();
            return py::make_tuple(py::bytes(metrics), duration);
          },
          "Get serialized metrics for bucket and duration as (bytes, float)")
      .def(
          "get_aggregate_ending_now",
          [](const DataLoader& self, float duration_seconds,
             bool include_pending) {
            auto [metrics, duration] = [&] {
              py::gil_scoped_release release;
              return self.GetAggregateEndingNow(duration_seconds,
                                                include_pending);
            }();
            return py::make_tuple(py::bytes(metrics), duration);
          },
          "Get serialized metrics for aggregate duration and actual duration "
          "as (bytes, float)")
      .def("start", &DataLoader::Start, "Start the data loader processing")
      .def("stop", &DataLoader::Stop, "Stop the data loader");
  // Expose TensorBase for potential advanced usage.
  py::class_(m, "TensorBase")
      .def("shape", &TensorBase::shape, py::return_value_policy::reference)
      .def("strides", &TensorBase::strides, py::return_value_policy::reference)
      .def("element_size", &TensorBase::element_size)
      .def("py_format", &TensorBase::py_format);
}

} // namespace training
} // namespace lczero
================================================ FILE: csrc/loader/stages/chunk_rescorer.cc ================================================
#include "loader/stages/chunk_rescorer.h"
// NOTE(review): include targets lost in extraction.
#include #include
#include "absl/base/call_once.h"
#include "absl/log/log.h"
#include "chess/board.h"
#include "loader/data_loader_metrics.h"
namespace lczero {
namespace training {
namespace {

// Upgrades V6-format frames to V7 in place: fills q_st/d_st with an
// exponential moving average computed from the end of the game backwards
// (theta is the EMA retention factor), and fills the opponent/next
// played-move lookahead indices.
// NOTE(review): the span's element type was stripped by extraction;
// presumably std::span of FrameType.
void V6ToV7(std::span data, float theta = 5.0f / 6.0f) {
  if (data.empty()) return;
  const float beta = 1.0f - theta;
  float st_q = 0.0f;
  float st_d = 0.0f;
  // Iterate backwards to calculate EMA and lookaheads in one go.
for (size_t i = data.size(); i-- > 0;) {
    FrameType& item = data[i];
    // For Q, we operate in an alternating sign domain.
    const float sign = (i % 2 != 0) ? -1.0f : 1.0f;
    const float cur_q = item.root_q * sign;
    if (i == data.size() - 1) {
      // Last position seeds the moving averages.
      st_q = cur_q;
      st_d = item.root_d;
    } else {
      st_q = theta * st_q + beta * cur_q;
      st_d = theta * st_d + beta * item.root_d;
    }
    item.version = 7;
    item.q_st = st_q * sign;
    item.d_st = std::max(st_d, 0.0f);
    // Handle lookahead indices safely.
    // 65535 is the sentinel for "no such future move" near the game end.
    auto get_idx = [&](size_t offset) {
      return (i + offset < data.size()) ? data[i + offset].played_idx : 65535;
    };
    item.opp_played_idx = get_idx(1);
    item.next_played_idx = get_idx(2);
  }
}

} // namespace

// Reads rescoring parameters from config, optionally configures deblundering
// and Gaviota tablebases, and pre-creates one context per worker thread.
// Magic bitboards are initialized process-wide exactly once.
ChunkRescorer::ChunkRescorer(const ChunkRescorerConfig& config)
    : SingleInputStage(config),
      SingleOutputStage(config.output()),
      syzygy_paths_(config.syzygy_paths()),
      dist_temp_(config.dist_temp()),
      dist_offset_(config.dist_offset()),
      dtz_boost_(config.dtz_boost()),
      new_input_format_(config.new_input_format()),
      thread_pool_(config.threads(), ThreadPoolOptions{}),
      st_q_theta_(config.st_q_theta()) {
  static absl::once_flag bitboards_initialized_flag;
  absl::call_once(bitboards_initialized_flag, InitializeMagicBitboards);
  if (config.has_deblunder_threshold() && config.has_deblunder_width()) {
    RescorerDeblunderSetup(config.deblunder_threshold(),
                           config.deblunder_width());
  }
  if (config.has_gaviota_paths()) {
    RescorerGaviotaSetup(std::string(config.gaviota_paths()));
  }
  LOG(INFO) << "Initializing ChunkRescorer with " << config.threads()
            << " worker thread(s)";
  thread_contexts_.reserve(config.threads());
  for (size_t i = 0; i < config.threads(); ++i) {
    thread_contexts_.push_back(std::make_unique());
  }
}

ChunkRescorer::~ChunkRescorer() { Stop(); }

// Lazily initializes the Syzygy tablebase; on failure, rescoring proceeds
// without tablebase lookups (logged as a warning, not an error).
void ChunkRescorer::InitializeTablebase() {
  if (tablebase_initialized_) {
    return;
  }
  LOG(INFO) << "ChunkRescorer initializing Syzygy tablebase with paths '"
            << syzygy_paths_ << "'.";
  tablebase_initialized_ = tablebase_.init(syzygy_paths_);
  if (tablebase_initialized_) {
    LOG(INFO) << "ChunkRescorer Syzygy max cardinality: "
              << tablebase_.max_cardinality();
  } else {
    LOG(WARNING) << "ChunkRescorer failed to initialize Syzygy tablebase; "
                    "rescoring will continue without tablebase lookups.";
  }
}

// Initializes the tablebase and schedules one Worker per thread context.
void ChunkRescorer::Start() {
  LOG(INFO) << "Starting ChunkRescorer worker threads.";
  InitializeTablebase();
  for (size_t i = 0; i < thread_contexts_.size(); ++i) {
    thread_pool_.Enqueue([this, i](std::stop_token stop_token) {
      Worker(stop_token, thread_contexts_[i].get());
    });
  }
}

// Idempotent shutdown: drains the pool, then closes the output queue so
// downstream consumers observe end-of-stream.
void ChunkRescorer::Stop() {
  if (thread_pool_.stop_token().stop_requested()) return;
  LOG(INFO) << "Stopping ChunkRescorer.";
  thread_pool_.Shutdown();
  output_queue()->Close();
  LOG(INFO) << "ChunkRescorer stopped.";
}

// Worker loop: pull a chunk, rescore, upgrade V6 frames to V7, push
// downstream. Per-chunk failures are counted and logged but do not stop the
// worker; queue closure or cancellation ends the loop. LoadMetricPauser marks
// queue-wait time as idle so load metrics reflect useful work only.
void ChunkRescorer::Worker(std::stop_token stop_token, ThreadContext* context) {
  auto producer = output_queue()->CreateProducer();
  try {
    while (true) {
      TrainingChunk chunk = [&]() {
        LoadMetricPauser pauser(context->load_metric_updater);
        return input_queue()->Get(stop_token);
      }();
      try {
        chunk.frames = RescoreTrainingData(
            chunk.frames, &tablebase_, dist_temp_, dist_offset_, dtz_boost_,
            new_input_format_);
        // NOTE(review): .at(0) on an empty frames vector throws here and is
        // counted as a failed rescore by the catch below.
        if (chunk.frames.at(0).version == 6) V6ToV7(chunk.frames, st_q_theta_);
        LoadMetricPauser pauser(context->load_metric_updater);
        producer.Put(std::move(chunk), stop_token);
      } catch (const std::exception& exception) {
        failed_rescores_.fetch_add(1, std::memory_order_acq_rel);
        LOG(ERROR) << "ChunkRescorer failed to rescore chunk: "
                   << exception.what() << "; sort_key=" << chunk.sort_key
                   << "; index_within_sort_key=" << chunk.index_within_sort_key
                   << "; global_index=" << chunk.global_index
                   << "; use_count=" << chunk.use_count
                   << "; frame_count=" << chunk.frames.size();
        continue;
      }
    }
  } catch (const QueueClosedException&) {
    LOG(INFO) << "ChunkRescorer worker stopping, queue closed.";
  } catch (const QueueRequestCancelled&) {
    LOG(INFO) << "ChunkRescorer worker stopping, request cancelled.";
  }
}

// Aggregates per-thread load metrics into one "load" entry, reports and
// resets the failed-rescore counter, and snapshots the output queue.
StageMetricProto ChunkRescorer::FlushMetrics() {
  StageMetricProto stage_metric;
  LoadMetricProto aggregated_load;
  aggregated_load.set_name("load");
  for (const auto& context : thread_contexts_) {
    UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());
  }
  *stage_metric.add_load_metrics() = std::move(aggregated_load);
  auto* failed = stage_metric.add_count_metrics();
  failed->set_name("failed_rescores");
  failed->set_count(failed_rescores_.exchange(0, std::memory_order_acq_rel));
  *stage_metric.add_queue_metrics() =
      MetricsFromQueue("output", *output_queue());
  return stage_metric;
}

} // namespace training
} // namespace lczero
================================================ FILE: csrc/loader/stages/chunk_rescorer.h ================================================
// ABOUTME: Stage that rescores training chunks using Syzygy tablebases.
// ABOUTME: Adjusts frame metadata by invoking the classic LCZero rescorer.
#pragma once
// NOTE(review): include targets lost in extraction.
#include #include #include #include #include #include
#include "libs/lc0/src/syzygy/syzygy.h"
#include "libs/lc0/src/trainingdata/rescorer.h"
#include "loader/frame_type.h"
#include "loader/stages/stage.h"
#include "loader/stages/training_chunk.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/load_metric.h"
#include "utils/queue.h"
#include "utils/thread_pool.h"
namespace lczero {
namespace training {
// Stage that takes TrainingChunk objects, applies tablebase-based rescoring and
// forwards the updated chunks downstream.
class ChunkRescorer : public SingleInputStage, public SingleOutputStage { public: using InputType = TrainingChunk; using OutputType = TrainingChunk; explicit ChunkRescorer(const ChunkRescorerConfig& config); ~ChunkRescorer() override; void Start() override; void Stop() override; StageMetricProto FlushMetrics() override; private: struct ThreadContext { LoadMetricUpdater load_metric_updater; }; void Worker(std::stop_token stop_token, ThreadContext* context); void InitializeTablebase(); SyzygyTablebase tablebase_; bool tablebase_initialized_ = false; std::string syzygy_paths_; float dist_temp_; float dist_offset_; float dtz_boost_; int new_input_format_; ThreadPool thread_pool_; std::vector> thread_contexts_; std::atomic failed_rescores_{0}; float st_q_theta_; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_rescorer_test.cc ================================================ #include "loader/stages/chunk_rescorer.h" #include #include #include #include "gtest/gtest.h" #include "loader/stages/training_chunk.h" #include "proto/data_loader_config.pb.h" #include "utils/queue.h" namespace lczero { namespace training { namespace { template class PassthroughStage : public Stage { public: explicit PassthroughStage(Queue* queue) : queue_(queue) {} void Start() override {} void Stop() override {} StageMetricProto FlushMetrics() override { return StageMetricProto(); } QueueBase* GetOutput(std::string_view name = "") override { (void)name; return queue_; } void SetInputs(absl::Span inputs) override { if (!inputs.empty()) { throw std::runtime_error("PassthroughStage expects no inputs"); } } private: Queue* queue_; }; } // namespace class ChunkRescorerTest : public ::testing::Test { protected: void SetUp() override { input_queue_ = std::make_unique>(10); config_.set_threads(1); config_.mutable_output()->set_queue_capacity(10); config_.set_syzygy_paths(""); config_.set_dist_temp(0.75f); 
config_.set_dist_offset(0.1f); config_.set_dtz_boost(0.2f); config_.set_new_input_format(-1); } TrainingChunk MakeChunk(std::vector frames, std::string sort_key = "alpha", size_t index = 3, uint32_t use = 7) { TrainingChunk chunk; chunk.sort_key = std::move(sort_key); chunk.index_within_sort_key = index; chunk.use_count = use; chunk.frames = std::move(frames); return chunk; } std::unique_ptr> input_queue_; ChunkRescorerConfig config_; }; TEST_F(ChunkRescorerTest, HandlesInputQueueClosure) { ChunkRescorer rescorer(config_); rescorer.SetInputs({input_queue_.get()}); rescorer.Start(); input_queue_->Close(); EXPECT_THROW(rescorer.output_queue()->Get(), QueueClosedException); rescorer.Stop(); } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_source_loader.cc ================================================ #include "loader/stages/chunk_source_loader.h" #include #include #include "absl/log/log.h" #include "loader/chunk_source/rawfile_chunk_source.h" #include "loader/chunk_source/tar_chunk_source.h" #include "loader/data_loader_metrics.h" #include "proto/data_loader_config.pb.h" namespace lczero { namespace training { std::unique_ptr CreateChunkSourceFromFile( const std::filesystem::path& filepath, ChunkSourceLoaderConfig::FrameFormat frame_format) { auto extension = filepath.extension(); try { if (extension == ".gz") { return std::make_unique(filepath, frame_format); } if (extension == ".tar") { return std::make_unique(filepath, frame_format); } } catch (const std::exception& e) { LOG(ERROR) << "Failed to create chunk source for " << filepath << ": " << e.what(); return nullptr; } return nullptr; } ChunkSourceLoader::ChunkSourceLoader(const ChunkSourceLoaderConfig& config) : SingleInputStage(config), SingleOutputStage(config.output()), thread_pool_(config.threads(), ThreadPoolOptions{}), frame_format_(config.frame_format()) { LOG(INFO) << "Initializing ChunkSourceLoader with " << 
config.threads() << " worker threads"; // Initialize thread contexts but don't start worker threads yet. thread_contexts_.reserve(config.threads()); for (size_t i = 0; i < config.threads(); ++i) { thread_contexts_.push_back(std::make_unique()); } } ChunkSourceLoader::~ChunkSourceLoader() { Stop(); } void ChunkSourceLoader::Start() { LOG(INFO) << "Starting ChunkSourceLoader worker threads."; for (size_t i = 0; i < thread_contexts_.size(); ++i) { thread_pool_.Enqueue([this, i](std::stop_token stop_token) { Worker(stop_token, thread_contexts_[i].get()); }); } } void ChunkSourceLoader::Stop() { if (thread_pool_.stop_token().stop_requested()) return; LOG(INFO) << "Stopping ChunkSourceLoader."; thread_pool_.Shutdown(); output_queue()->Close(); LOG(INFO) << "ChunkSourceLoader stopped."; } void ChunkSourceLoader::Worker(std::stop_token stop_token, ThreadContext* context) { auto producer = output_queue()->CreateProducer(); LOG(INFO) << "ChunkSourceLoader worker@" << static_cast(context) << " started."; try { while (true) { auto file = [&]() { LoadMetricPauser pauser(context->load_metric_updater); return input_queue()->Get(stop_token); }(); if (file.message_type == FilePathProvider::MessageType::kInitialScanComplete) { LOG(INFO) << "ChunkSourceLoader received initial scan completion marker."; bool should_forward; { absl::MutexLock lock(&phase_mutex_); sentinel_received_ = true; should_forward = (pre_sentinel_work_count_ == 0); } if (should_forward) { LOG(INFO) << "ChunkSourceLoader forwarding initial scan completion " "marker."; producer.Put({.source = nullptr, .message_type = file.message_type}, stop_token); } continue; } // Track pre-sentinel work. bool is_pre_sentinel; { absl::MutexLock lock(&phase_mutex_); is_pre_sentinel = !sentinel_received_; if (is_pre_sentinel) pre_sentinel_work_count_++; } // Create ChunkSource from the file. 
LOG_EVERY_N(INFO, 1000) << "ChunkSourceLoader preparing chunk source for " << file.filepath; auto source = CreateChunkSourceFromFile(file.filepath, frame_format_); if (source) { { absl::MutexLock lock(&last_chunk_key_mutex_); last_chunk_key_ = source->GetChunkSortKey(); } ChunkSourceWithPhase output{.source = std::move(source), .message_type = file.message_type}; LoadMetricPauser pauser(context->load_metric_updater); producer.Put(std::move(output), stop_token); } else { LOG_EVERY_N(INFO, 100) << "ChunkSourceLoader skipping unsupported file: " << file.filepath; skipped_files_count_++; } // Complete pre-sentinel work tracking. if (is_pre_sentinel) { absl::MutexLock lock(&phase_mutex_); if (--pre_sentinel_work_count_ == 0 && sentinel_received_) { LOG(INFO) << "ChunkSourceLoader forwarding initial scan completion " "marker after all pre-sentinel work completed."; producer.Put( {.source = nullptr, .message_type = FilePathProvider::MessageType::kInitialScanComplete}, stop_token); } } } } catch (const QueueClosedException&) { LOG(INFO) << "ChunkSourceLoader worker@" << static_cast(context) << " stopping, queue closed."; } catch (const QueueRequestCancelled&) { LOG(INFO) << "ChunkSourceLoader worker@" << static_cast(context) << " stopping, request cancelled."; } catch (const std::exception& e) { LOG(ERROR) << "ChunkSourceLoader worker@" << static_cast(context) << " exiting due to exception: " << e.what(); throw; } LOG(INFO) << "ChunkSourceLoader worker@" << static_cast(context) << " exiting loop."; } StageMetricProto ChunkSourceLoader::FlushMetrics() { StageMetricProto stage_metric; LoadMetricProto aggregated_load; aggregated_load.set_name("load"); for (const auto& context : thread_contexts_) { UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics()); } *stage_metric.add_load_metrics() = std::move(aggregated_load); auto* skipped_metric = stage_metric.add_count_metrics(); skipped_metric->set_name("skipped_files"); 
skipped_metric->set_count(skipped_files_count_.exchange(0)); // Get the last chunk key. { absl::MutexLock lock(&last_chunk_key_mutex_); if (!last_chunk_key_.empty()) { stage_metric.set_last_chunk_key(last_chunk_key_); } } *stage_metric.add_queue_metrics() = MetricsFromQueue("output", *output_queue()); return stage_metric; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_source_loader.h ================================================ #pragma once #include #include #include #include #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "loader/chunk_source/chunk_source.h" #include "loader/stages/file_path_provider.h" #include "loader/stages/stage.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" #include "utils/metrics/load_metric.h" #include "utils/queue.h" #include "utils/thread_pool.h" namespace lczero { namespace training { // Creates a ChunkSource based on file extension. Returns RawFileChunkSource for // .gz files, TarChunkSource for .tar files, or nullptr for unsupported types. std::unique_ptr CreateChunkSourceFromFile( const std::filesystem::path& filepath, ChunkSourceLoaderConfig::FrameFormat frame_format); struct ChunkSourceWithPhase { std::unique_ptr source; FilePathProvider::MessageType message_type; }; // Worker pool that converts FilePathProvider output to ChunkSource objects. // Takes FilePathProvider::File as input and outputs ChunkSourceWithPhase. 
class ChunkSourceLoader : public SingleInputStage, public SingleOutputStage { public: using InputType = FilePathProvider::File; using OutputType = ChunkSourceWithPhase; explicit ChunkSourceLoader(const ChunkSourceLoaderConfig& config); ~ChunkSourceLoader(); void Start() override; void Stop() override; StageMetricProto FlushMetrics() override; private: struct ThreadContext { LoadMetricUpdater load_metric_updater; }; void Worker(std::stop_token stop_token, ThreadContext* context); ThreadPool thread_pool_; std::vector> thread_contexts_; std::atomic skipped_files_count_{0}; absl::Mutex last_chunk_key_mutex_; std::string last_chunk_key_; ChunkSourceLoaderConfig::FrameFormat frame_format_; // Synchronization for sentinel barrier. absl::Mutex phase_mutex_; int pre_sentinel_work_count_ ABSL_GUARDED_BY(phase_mutex_) = 0; bool sentinel_received_ ABSL_GUARDED_BY(phase_mutex_) = false; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_source_loader_test.cc ================================================ #include "loader/stages/chunk_source_loader.h" #include #include #include "loader/stages/file_path_provider.h" #include "utils/queue.h" namespace lczero { namespace training { namespace { template class PassthroughStage : public Stage { public: explicit PassthroughStage(Queue* queue) : queue_(queue) {} void Start() override {} void Stop() override {} StageMetricProto FlushMetrics() override { return StageMetricProto(); } QueueBase* GetOutput(std::string_view name = "") override { (void)name; return queue_; } void SetInputs(absl::Span inputs) override { if (!inputs.empty()) { throw std::runtime_error("PassthroughStage expects no inputs"); } } private: Queue* queue_; }; } // namespace TEST(ChunkSourceLoaderTest, ProcessesFiles) { Queue input_queue(10); ChunkSourceLoaderConfig config; config.set_threads(1); config.mutable_output()->set_queue_capacity(10); ChunkSourceLoader feed(config); 
feed.SetInputs({&input_queue}); feed.Start(); { auto producer = input_queue.CreateProducer(); // Add a file with unsupported extension (should not create ChunkSource) producer.Put(FilePathProvider::File{ .filepath = std::filesystem::path("/test.txt"), // unsupported extension .message_type = FilePathProvider::MessageType::kFile}); } // Producer destroyed here, closing input queue // Try to get output - there should be no valid ChunkSources for unsupported // files try { while (true) { auto output = feed.output_queue()->Get(); // If we get output, it means a ChunkSource was created, which shouldn't // happen for unsupported files FAIL() << "Expected no output for unsupported file extension"; } } catch (const QueueClosedException&) { // Expected: queue should be closed when input is done and no output // produced SUCCEED(); } } TEST(ChunkSourceLoaderTest, HandlesPhases) { Queue input_queue(10); ChunkSourceLoaderConfig config; config.set_threads(1); config.mutable_output()->set_queue_capacity(10); ChunkSourceLoader feed(config); feed.SetInputs({&input_queue}); feed.Start(); { auto producer = input_queue.CreateProducer(); // Test different phases - all should be passed through even if no // ChunkSource is created producer.Put(FilePathProvider::File{ .filepath = std::filesystem::path("/test1.gz"), .message_type = FilePathProvider::MessageType::kFile}); producer.Put(FilePathProvider::File{ .filepath = std::filesystem::path("/test2.gz"), .message_type = FilePathProvider::MessageType::kFile}); } // Producer destroyed here, closing input queue // Queue should eventually close when input is done try { while (true) { feed.output_queue()->Get(); } } catch (const QueueClosedException&) { SUCCEED(); } } TEST(ChunkSourceLoaderTest, PassesThroughInitialScanComplete) { Queue input_queue(10); ChunkSourceLoaderConfig config; config.set_threads(1); config.mutable_output()->set_queue_capacity(10); ChunkSourceLoader feed(config); feed.SetInputs({&input_queue}); feed.Start(); { auto 
producer = input_queue.CreateProducer(); producer.Put(FilePathProvider::File{ .filepath = std::filesystem::path(""), .message_type = FilePathProvider::MessageType::kInitialScanComplete}); } // Producer destroyed here, closing input queue // Should get kInitialScanComplete in output with null ChunkSource auto output = feed.output_queue()->Get(); EXPECT_EQ(output.message_type, FilePathProvider::MessageType::kInitialScanComplete); EXPECT_EQ(output.source, nullptr); // Queue should be closed after the single message try { feed.output_queue()->Get(); FAIL() << "Expected queue to be closed"; } catch (const QueueClosedException&) { SUCCEED(); } } TEST(ChunkSourceLoaderTest, SentinelBarrierWithMultipleThreads) { Queue input_queue(100); ChunkSourceLoaderConfig config; config.set_threads(4); config.mutable_output()->set_queue_capacity(100); ChunkSourceLoader feed(config); feed.SetInputs({&input_queue}); feed.Start(); { auto producer = input_queue.CreateProducer(); // Add files that will be processed before sentinel. for (int i = 0; i < 20; ++i) { producer.Put(FilePathProvider::File{ .filepath = std::filesystem::path("/test" + std::to_string(i) + ".txt"), .message_type = FilePathProvider::MessageType::kFile}); } // Add sentinel. producer.Put(FilePathProvider::File{ .filepath = std::filesystem::path(""), .message_type = FilePathProvider::MessageType::kInitialScanComplete}); // Add files that arrive after sentinel. for (int i = 20; i < 30; ++i) { producer.Put(FilePathProvider::File{ .filepath = std::filesystem::path("/test" + std::to_string(i) + ".txt"), .message_type = FilePathProvider::MessageType::kFile}); } } // Producer destroyed here, closing input queue // Read all outputs and verify sentinel comes after all pre-sentinel files. 
int files_before_sentinel = 0; int files_after_sentinel = 0; bool sentinel_seen = false; try { while (true) { auto output = feed.output_queue()->Get(); if (output.message_type == FilePathProvider::MessageType::kInitialScanComplete) { EXPECT_FALSE(sentinel_seen) << "Sentinel should appear exactly once"; sentinel_seen = true; } else { if (sentinel_seen) { files_after_sentinel++; } else { files_before_sentinel++; } } } } catch (const QueueClosedException&) { } // Verify sentinel was seen. EXPECT_TRUE(sentinel_seen); // All 20 pre-sentinel files should be before sentinel (unsupported, so 0). EXPECT_EQ(files_before_sentinel, 0); // All 10 post-sentinel files should be after sentinel (unsupported, so 0). EXPECT_EQ(files_after_sentinel, 0); } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_source_splitter.cc ================================================ #include "loader/stages/chunk_source_splitter.h" #include #include #include #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "loader/data_loader_metrics.h" namespace lczero { namespace training { ChunkSourceSplitter::ChunkSourceSplitter( const ChunkSourceSplitterConfig& config) : SingleInputStage(config) { if (config.output().empty()) { throw std::runtime_error("ChunkSourceSplitter requires at least 1 output."); } // Validate parallel arrays have same size. if (config.output_size() != config.weight_size()) { throw std::runtime_error(absl::StrCat( "ChunkSourceSplitter output and weight arrays must have same size: ", config.output_size(), " vs ", config.weight_size())); } // Create output queues from parallel arrays. outputs_.reserve(config.output_size()); for (size_t i = 0; i < static_cast(config.output_size()); ++i) { const auto& queue_cfg = config.output(static_cast(i)); const uint64_t weight = i < static_cast(config.weight_size()) ? 
config.weight(static_cast(i)) : 1; if (absl::c_any_of(outputs_, [&](const auto& existing_out) { return existing_out->name == queue_cfg.name(); })) { throw std::runtime_error(std::string(absl::StrCat( "Duplicate output name in ChunkSourceSplitter: ", queue_cfg.name()))); } auto* out = outputs_ .emplace_back(std::make_unique( queue_cfg.name(), weight, queue_cfg.queue_capacity(), ToOverflowBehavior(queue_cfg.overflow_behavior()))) .get(); LOG(INFO) << "ChunkSourceSplitter configured output '" << out->name << "' weight=" << out->weight << " capacity=" << queue_cfg.queue_capacity(); } // Precompute cumulative weights for fast assignment. cumulative_.resize(outputs_.size()); std::transform_inclusive_scan( outputs_.begin(), outputs_.end(), cumulative_.begin(), std::plus{}, [](const std::unique_ptr& out) { return out->weight; }); // Validate total weight is positive. if (cumulative_.back() == 0) { throw std::runtime_error( "ChunkSourceSplitter requires at least one output with positive " "weight."); } } ChunkSourceSplitter::~ChunkSourceSplitter() { Stop(); } void ChunkSourceSplitter::Start() { LOG(INFO) << "Starting ChunkSourceSplitter worker."; thread_pool_.Enqueue( [this](std::stop_token stop_token) { Worker(stop_token); }); } void ChunkSourceSplitter::Stop() { if (thread_pool_.stop_token().stop_requested()) return; LOG(INFO) << "Stopping ChunkSourceSplitter."; thread_pool_.Shutdown(); for (auto& out : outputs_) out->queue.Close(); } QueueBase* ChunkSourceSplitter::GetOutput(std::string_view name) { auto iter = absl::c_find_if( outputs_, [&](const auto& out) { return out->name == name; }); if (iter == outputs_.end()) { throw std::runtime_error( absl::StrCat("Unknown output '", name, "' for ChunkSourceSplitter.")); } return &(*iter)->queue; } StageMetricProto ChunkSourceSplitter::FlushMetrics() { StageMetricProto metric; for (auto& out : outputs_) { *metric.add_queue_metrics() = MetricsFromQueue(out->name, out->queue); } return metric; } void 
ChunkSourceSplitter::Worker(std::stop_token stop_token) { // Create producers for each output in this thread. std::vector::Producer> producers; producers.reserve(outputs_.size()); for (auto& out : outputs_) { producers.emplace_back(out->queue.CreateProducer()); } try { while (true) { InputType item = input_queue()->Get(stop_token); if (item.message_type == FilePathProvider::MessageType::kInitialScanComplete) { // Broadcast to all outputs. for (auto& prod : producers) { prod.Put( OutputType{.source = nullptr, .message_type = item.message_type}, stop_token); } continue; } // Share ownership of the ChunkSource with any produced views. std::shared_ptr shared_source(std::move(item.source)); auto per_output_indices = BuildAssignments(shared_source); // Emit only non-empty views, preserving original message type (kFile). for (size_t i = 0; i < outputs_.size(); ++i) { if (per_output_indices[i].empty()) continue; auto view = std::make_unique( shared_source, std::move(per_output_indices[i])); producers[i].Put( OutputType{.source = std::move(view), .message_type = FilePathProvider::MessageType::kFile}, stop_token); } } } catch (const QueueClosedException&) { // Input queue closed — producers will close queues automatically when // destroyed if this thread holds the last producer. LOG(INFO) << "ChunkSourceSplitter worker exiting: input closed."; } } std::vector> ChunkSourceSplitter::BuildAssignments( const std::shared_ptr& source) { const std::string sort_key = source->GetChunkSortKey(); const size_t n = source->GetChunkCount(); // Prepare result containers with a rough reservation. std::vector> indices(outputs_.size()); for (size_t i = 0; i < n; ++i) { const uint64_t h = static_cast(absl::Hash>{}( std::make_pair(sort_key, i))); const uint64_t r = h % cumulative_.back(); // Find the output where cumulative[j-1] <= r < cumulative[j]. 
const auto it = std::upper_bound(cumulative_.begin(), cumulative_.end(), r); const size_t idx = it - cumulative_.begin(); indices[idx].push_back(static_cast(i)); } return indices; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_source_splitter.h ================================================ #pragma once #include #include #include #include #include #include #include "absl/base/thread_annotations.h" #include "absl/hash/hash.h" #include "absl/log/log.h" #include "loader/chunk_source/chunk_source.h" #include "loader/chunk_source/chunk_source_view.h" #include "loader/stages/chunk_source_loader.h" #include "loader/stages/stage.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" #include "utils/queue.h" #include "utils/thread_pool.h" namespace lczero { namespace training { // Splits an incoming ChunkSource into several ChunkSourceViews based on a // deterministic hash of (sort_key, index). Emits to multiple named outputs. class ChunkSourceSplitter : public SingleInputStage { public: using InputType = ChunkSourceWithPhase; using OutputType = ChunkSourceWithPhase; explicit ChunkSourceSplitter(const ChunkSourceSplitterConfig& config); ~ChunkSourceSplitter(); void Start() override; void Stop() override; StageMetricProto FlushMetrics() override; QueueBase* GetOutput(std::string_view name = "") override; private: struct Output { std::string name; uint64_t weight; Queue queue; Output(std::string_view name, uint64_t weight, size_t capacity, OverflowBehavior overflow) : name(name), weight(weight), queue(capacity, overflow) {} }; void Worker(std::stop_token stop_token); // Builds per-output indices given a source; uses absl::Hash on // (sort_key, index) and weights to assign indices. 
std::vector> BuildAssignments( const std::shared_ptr& source); std::vector> outputs_; std::vector cumulative_; ThreadPool thread_pool_{1}; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_source_splitter_test.cc ================================================ #include "loader/stages/chunk_source_splitter.h" #include #include #include #include #include "absl/hash/hash.h" #include "gtest/gtest.h" #include "loader/chunk_source/chunk_source.h" #include "loader/stages/chunk_source_loader.h" #include "loader/stages/stage.h" #include "proto/data_loader_config.pb.h" #include "utils/queue.h" namespace lczero { namespace training { namespace { // Simple fixed-count chunk source for testing. class FixedCountChunkSource : public ChunkSource { public: FixedCountChunkSource(std::string sort_key, size_t count) : key_(std::move(sort_key)), count_(count) {} private: std::string GetChunkSortKey() const override { return key_; } size_t GetChunkCount() const override { return count_; } std::optional> GetChunkData(size_t) override { return std::vector{FrameType{}}; } std::string key_; size_t count_; }; template class PassthroughStage : public Stage { public: explicit PassthroughStage(Queue* queue) : queue_(queue) {} void Start() override {} void Stop() override {} StageMetricProto FlushMetrics() override { return StageMetricProto(); } QueueBase* GetOutput(std::string_view) override { return queue_; } void SetInputs(absl::Span inputs) override { if (!inputs.empty()) { throw std::runtime_error("PassthroughStage expects no inputs"); } } private: Queue* queue_; }; } // namespace TEST(ChunkSourceSplitterTest, SplitsByHashAndWeight) { // Upstream queue. auto input_queue = std::make_unique>(8); // Configure splitter with two outputs A:1, B:2. 
ChunkSourceSplitterConfig cfg; auto* outA = cfg.add_output(); outA->set_name("A"); outA->set_queue_capacity(8); cfg.add_weight(1); auto* outB = cfg.add_output(); outB->set_name("B"); outB->set_queue_capacity(8); cfg.add_weight(2); ChunkSourceSplitter splitter(cfg); splitter.SetInputs({input_queue.get()}); splitter.Start(); // Send a source with known key and count. const std::string key = "skey"; const size_t count = 100; ChunkSourceWithPhase item; item.source = std::make_unique(key, count); item.message_type = FilePathProvider::MessageType::kFile; auto producer = input_queue->CreateProducer(); producer.Put(std::move(item)); producer.Close(); // Compute expected assignment counts using the same hash/weights. const uint64_t total_weight = 3; uint64_t cumA = 1; // [0] size_t expectedA = 0; size_t expectedB = 0; for (size_t i = 0; i < count; ++i) { const uint64_t h = static_cast( absl::Hash>{}(std::make_pair(key, i))); const uint64_t r = h % total_weight; if (r < cumA) ++expectedA; else ++expectedB; } // Read outputs and verify view sizes. 
auto* qa = dynamic_cast*>(splitter.GetOutput("A")); auto* qb = dynamic_cast*>(splitter.GetOutput("B")); ASSERT_NE(qa, nullptr); ASSERT_NE(qb, nullptr); auto msgA = qa->Get(); auto msgB = qb->Get(); ASSERT_NE(msgA.source, nullptr); ASSERT_NE(msgB.source, nullptr); EXPECT_EQ(msgA.source->GetChunkCount(), expectedA); EXPECT_EQ(msgB.source->GetChunkCount(), expectedB); splitter.Stop(); } TEST(ChunkSourceSplitterTest, BroadcastsInitialScanComplete) { auto input_queue = std::make_unique>(4); ChunkSourceSplitterConfig cfg; auto* outA = cfg.add_output(); outA->set_name("A"); cfg.add_weight(1); auto* outB = cfg.add_output(); outB->set_name("B"); cfg.add_weight(1); ChunkSourceSplitter splitter(cfg); splitter.SetInputs({input_queue.get()}); splitter.Start(); ChunkSourceWithPhase marker; marker.source = nullptr; marker.message_type = FilePathProvider::MessageType::kInitialScanComplete; auto producer = input_queue->CreateProducer(); producer.Put(std::move(marker)); producer.Close(); auto* qa = dynamic_cast*>(splitter.GetOutput("A")); auto* qb = dynamic_cast*>(splitter.GetOutput("B")); auto m1 = qa->Get(); auto m2 = qb->Get(); EXPECT_EQ(m1.message_type, FilePathProvider::MessageType::kInitialScanComplete); EXPECT_EQ(m2.message_type, FilePathProvider::MessageType::kInitialScanComplete); EXPECT_EQ(m1.source, nullptr); EXPECT_EQ(m2.source, nullptr); splitter.Stop(); } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_unpacker.cc ================================================ #include "loader/stages/chunk_unpacker.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include "absl/numeric/int128.h" #include "absl/random/bit_gen_ref.h" #include "absl/random/random.h" #include "loader/data_loader_metrics.h" #include "loader/stages/position_sampling.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" namespace 
lczero { namespace training { // Deterministically partitions `n` positions into disjoint subsets of size // `~p*n`, returning the subset for a given `iteration`. While each selection // individually behaves like a Bernoulli sample with probability `p`, the // samples are correlated to ensure all positions are selected exactly once over // `1/p` iterations. To sample disjoint subsets from the same set of positions, // `gen` must be seeded identically for each call with a different `iteration`. std::vector PickSampledPositions(int32_t n, double p, int32_t iteration, absl::BitGen& gen) { assert(p > 0.0 && p <= 1.0); double carried_prob = p; std::vector result; absl::flat_hash_set skip_next_round; while (true) { int32_t num_this_round = (1.0 - carried_prob) / p + 1; double last_partial_prob = 1 - (carried_prob + (num_this_round - 1) * p); const bool return_this_round = iteration < num_this_round; absl::flat_hash_set skip_this_round(skip_next_round); skip_next_round.clear(); for (int32_t i = 0; i < n; ++i) { if (skip_this_round.contains(i)) continue; const double toss = absl::Uniform(gen, 0.0, 1.0); const int32_t value = (toss - carried_prob) / p + 1; if (value == iteration) result.push_back(static_cast(i)); if (value >= num_this_round) { skip_next_round.insert(static_cast(i)); } } if (return_this_round) return result; iteration -= num_this_round; carried_prob = p - last_partial_prob; } } // Samples `k` indices from an infinite, deterministically-generated sequence, // after an initial `skip`. The sequence is formed by repeatedly shuffling // `[0..n-1]` and probabilistically selecting each index. Determinism is // achieved by seeding a new PRNG for each shuffled block from a stable // `root_seed` and the block's index. 
std::vector SampleProbabilisticSequence( uint64_t k, uint64_t skip, std::span probabilities, absl::BitGen& gen) { const size_t n = probabilities.size(); if (n == 0 || k == 0) return {}; uint64_t skipped_so_far = 0; std::vector v(n); std::iota(v.begin(), v.end(), 0u); std::vector result; result.reserve(k); while (true) { std::shuffle(v.begin(), v.end(), gen); for (const uint32_t candidate : v) { if (!absl::Bernoulli(gen, probabilities[candidate])) continue; if (skipped_so_far < skip) { skipped_so_far++; } else { result.push_back(candidate); if (result.size() == k) return result; } } } } namespace { uint32_t GenerateRunSeed() { absl::BitGen gen(absl::MakeSeedSeq()); return absl::Uniform(gen); } } // namespace ChunkUnpacker::ChunkUnpacker(const ChunkUnpackerConfig& config) : SingleInputStage(config), config_(config), run_seed_(GenerateRunSeed()), primary_output_queue_( config.output().queue_capacity(), ToOverflowBehavior(config.output().overflow_behavior())), thread_pool_(config.threads(), ThreadPoolOptions{}) { const bool has_rate = config.has_position_sampling_rate(); const bool has_count = config.has_position_count(); const bool has_prefetch_count = config.has_prefetch_count(); const bool has_prefetch_output = config.has_prefetch_output(); CHECK(has_prefetch_count == has_prefetch_output) << "prefetch_count and prefetch_output must both be set or both unset."; if (has_prefetch_count) { CHECK(has_count) << "position_count must be set when using prefetch mode."; CHECK(!has_rate) << "position_sampling_rate cannot be used in prefetch mode."; CHECK(config.position_count() == 1) << "position_count must equal 1 in prefetch mode, got " << config.position_count(); prefetch_output_queue_.emplace( config.prefetch_output().queue_capacity(), ToOverflowBehavior(config.prefetch_output().overflow_behavior())); if (config.output().name() == config.prefetch_output().name()) { throw std::runtime_error( absl::StrCat("ChunkUnpacker output names must be different, got: '", 
config.output().name(), "'")); } } else { CHECK(has_rate != has_count) << "Exactly one of position_sampling_rate or position_count must be " "set."; } LOG(INFO) << "Initializing ChunkUnpacker with " << config.threads() << " worker threads"; // Initialize thread contexts but don't start worker threads yet. thread_contexts_.reserve(config.threads()); for (size_t i = 0; i < config.threads(); ++i) { thread_contexts_.push_back(std::make_unique()); } } ChunkUnpacker::~ChunkUnpacker() { Stop(); } void ChunkUnpacker::Start() { LOG(INFO) << "Starting ChunkUnpacker worker threads."; for (size_t i = 0; i < thread_contexts_.size(); ++i) { thread_pool_.Enqueue([this, i](std::stop_token stop_token) { Worker(stop_token, thread_contexts_[i].get()); }); } } void ChunkUnpacker::Stop() { if (thread_pool_.stop_token().stop_requested()) return; LOG(INFO) << "Stopping ChunkUnpacker."; thread_pool_.Shutdown(); primary_output_queue_.Close(); if (prefetch_output_queue_) prefetch_output_queue_->Close(); LOG(INFO) << "ChunkUnpacker stopped."; } QueueBase* ChunkUnpacker::GetOutput(std::string_view name) { if (name == config_.output().name()) return &primary_output_queue_; if (config_.has_prefetch_output() && name == config_.prefetch_output().name()) { return &*prefetch_output_queue_; } std::string available = absl::StrCat("'", config_.output().name(), "'"); if (config_.has_prefetch_output()) { absl::StrAppend(&available, ", '", config_.prefetch_output().name(), "'"); } throw std::runtime_error(absl::StrCat("ChunkUnpacker unknown output '", name, "'. 
Available outputs: ", available)); } namespace { std::vector FramesToProbabilities(std::span frames, const PositionSamplingConfig& config) { std::vector probabilities; probabilities.reserve(frames.size()); absl::c_transform(frames, std::back_inserter(probabilities), [&](const FrameType& frame) { return ComputePositionSamplingWeight(frame, config); }); const float max_prob = *absl::c_max_element(probabilities); if (max_prob > 0.0f) { absl::c_transform(probabilities, probabilities.begin(), [max_prob](float p) { return p / max_prob; }); } else { absl::c_fill(probabilities, 1.0f); } return probabilities; } } // namespace void ChunkUnpacker::Worker(std::stop_token stop_token, ThreadContext* context) { // Create a local producer for this worker thread. auto primary_producer = primary_output_queue_.CreateProducer(); std::optionalCreateProducer())> prefetch_producer; if (prefetch_output_queue_.has_value()) { prefetch_producer.emplace(prefetch_output_queue_->CreateProducer()); } try { while (true) { auto chunk = [&]() { LoadMetricPauser pauser(context->load_metric_updater); return input_queue()->Get(stop_token); }(); absl::BitGen gen( std::seed_seq{run_seed_, static_cast(chunk.global_index)}); std::vector positions; if (config_.has_position_sampling_rate()) { positions = PickSampledPositions( static_cast(chunk.frames.size()), config_.position_sampling_rate(), chunk.use_count, gen); } else { auto probabilities = FramesToProbabilities(chunk.frames, config_.position_sampling()); positions = SampleProbabilisticSequence( config_.position_count() + config_.prefetch_count(), config_.position_count() * chunk.use_count, probabilities, gen); } if (config_.has_prefetch_count()) { // Prefetch mode: output first position to primary, rest to prefetch. 
if (!positions.empty()) { LoadMetricPauser pauser(context->load_metric_updater); primary_producer.Put(std::move(chunk.frames[positions[0]]), stop_token); } if (positions.size() > 1) { CacheRequest cache_request; cache_request.global_index = chunk.global_index; cache_request.next_use = chunk.use_count + 1; cache_request.items.reserve(positions.size() - 1); for (size_t i = 1; i < positions.size(); ++i) { cache_request.items.push_back(chunk.frames[positions[i]]); } LoadMetricPauser pauser(context->load_metric_updater); prefetch_producer->Put(std::move(cache_request), stop_token); } } else { // Normal mode: output all positions to primary. for (uint32_t pos : positions) { LoadMetricPauser pauser(context->load_metric_updater); primary_producer.Put(std::move(chunk.frames[pos]), stop_token); } } } } catch (const QueueClosedException&) { LOG(INFO) << "ChunkUnpacker worker stopping, queue closed."; } catch (const QueueRequestCancelled&) { LOG(INFO) << "ChunkUnpacker worker stopping, request cancelled."; } } StageMetricProto ChunkUnpacker::FlushMetrics() { StageMetricProto stage_metric; LoadMetricProto aggregated_load; aggregated_load.set_name("load"); for (const auto& context : thread_contexts_) { UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics()); } *stage_metric.add_load_metrics() = std::move(aggregated_load); *stage_metric.add_queue_metrics() = MetricsFromQueue(config_.output().name(), primary_output_queue_); if (prefetch_output_queue_.has_value()) { *stage_metric.add_queue_metrics() = MetricsFromQueue( config_.prefetch_output().name(), *prefetch_output_queue_); } return stage_metric; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_unpacker.h ================================================ // ABOUTME: Stage that unpacks chunks into FrameType frames. // ABOUTME: Converts stream of std::string chunks to FrameType stream. 
// NOTE(review): this span holds chunk_unpacker.h plus the beginning of
// chunk_unpacker_test.cc, flattened onto two physical lines by extraction.
// The extraction also stripped every angle-bracket pair: the six bare
// "#include" directives below lost their <...> targets, and all template
// argument lists (e.g. Queue<...>, std::optional<...>, SingleInputStage<...>)
// are missing. Restore them from the repository before compiling.
// The header declares ChunkUnpacker: a worker-pool stage consuming
// TrainingChunk and producing FrameType frames, with a primary output queue
// and an optional prefetch output queue, per-thread LoadMetricUpdater
// contexts, and a run seed used for deterministic per-chunk sampling.
#pragma once #include #include #include #include #include #include "absl/random/random.h" #include "absl/types/optional.h" #include "loader/data_loader_metrics.h" #include "loader/frame_type.h" #include "loader/stages/stage.h" #include "loader/stages/training_chunk.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" #include "utils/queue.h" #include "utils/thread_pool.h" namespace lczero { namespace training { // Worker pool that unpacks chunks into frames. // Takes parsed TrainingChunk objects as input and outputs individual // FrameType frames. class ChunkUnpacker : public SingleInputStage { public: using InputType = TrainingChunk; explicit ChunkUnpacker(const ChunkUnpackerConfig& config); ~ChunkUnpacker(); void Start() override; void Stop() override; QueueBase* GetOutput(std::string_view name) override; StageMetricProto FlushMetrics() override; Queue* output_queue() { return &primary_output_queue_; } private: struct ThreadContext { LoadMetricUpdater load_metric_updater; }; void Worker(std::stop_token stop_token, ThreadContext* context); const ChunkUnpackerConfig config_; const uint32_t run_seed_; Queue primary_output_queue_; std::optional> prefetch_output_queue_; // thread_contexts_ must be declared before thread_pool_ to ensure // thread_pool_ is destroyed first (stopping threads before contexts).
// NOTE(review): the next line closes the ChunkUnpacker class (remaining
// members: thread_contexts_ then thread_pool_, ordered so pool destruction
// joins threads before their contexts die), declares PickSampledPositions,
// and then begins chunk_unpacker_test.cc: a PassthroughStage test helper and
// the ChunkUnpackerTest fixture (SetUp configures 1 thread, queue capacity
// 10, sampling rate 1.0; CreateTestFrame/MakeChunk build test inputs).
std::vector> thread_contexts_; ThreadPool thread_pool_; }; std::vector PickSampledPositions(int32_t n, double p, int32_t iteration, absl::BitGen& gen); } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/chunk_unpacker_test.cc ================================================ #include "loader/stages/chunk_unpacker.h" #include #include #include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/random/random.h" #include "absl/random/seed_sequences.h" #include "gtest/gtest.h" #include "libs/lc0/src/trainingdata/trainingdata_v6.h" #include "loader/stages/training_chunk.h" #include "proto/data_loader_config.pb.h" #include "utils/queue.h" namespace lczero { namespace training { namespace { template class PassthroughStage : public Stage { public: explicit PassthroughStage(Queue* queue) : queue_(queue) {} void Start() override {} void Stop() override {} StageMetricProto FlushMetrics() override { return StageMetricProto(); } QueueBase* GetOutput(std::string_view name = "") override { (void)name; return queue_; } void SetInputs(absl::Span inputs) override { if (!inputs.empty()) { throw std::runtime_error("PassthroughStage expects no inputs"); } } private: Queue* queue_; }; } // namespace class ChunkUnpackerTest : public ::testing::Test { protected: void SetUp() override { input_queue_ = std::make_unique>(10); config_.set_threads(1); config_.mutable_output()->set_queue_capacity(10); config_.set_position_sampling_rate(1.0f); } FrameType CreateTestFrame(uint32_t version) { FrameType frame{}; frame.version = version; frame.input_format = 3; frame.root_q = 0.5f; return frame; } TrainingChunk MakeChunk(std::vector frames, std::string sort_key = "source", size_t index = 0, uint32_t use = 0) { TrainingChunk chunk; chunk.sort_key = std::move(sort_key); chunk.index_within_sort_key = index; chunk.use_count = use; chunk.frames = std::move(frames); return chunk; }
// NOTE(review): continuation of chunk_unpacker_test.cc (flattened lines;
// template arguments stripped by extraction). This span closes the
// ChunkUnpackerTest fixture and holds the ChunkUnpacker end-to-end tests:
// single frame, multiple frames, multiple chunks, empty chunk, and input
// queue closure. Frame order is not guaranteed across workers, so the
// multi-frame tests sort versions before comparing.
std::unique_ptr> input_queue_; ChunkUnpackerConfig config_; }; TEST_F(ChunkUnpackerTest, UnpacksSingleFrame) { ChunkUnpacker unpacker(config_); unpacker.SetInputs({input_queue_.get()}); unpacker.Start(); FrameType test_frame = CreateTestFrame(6); auto producer = input_queue_->CreateProducer(); producer.Put(MakeChunk({test_frame})); producer.Close(); auto output_frame = unpacker.output_queue()->Get(); EXPECT_EQ(output_frame.version, 6); EXPECT_EQ(output_frame.input_format, 3); EXPECT_EQ(output_frame.root_q, 0.5f); } TEST_F(ChunkUnpackerTest, UnpacksMultipleFrames) { ChunkUnpacker unpacker(config_); unpacker.SetInputs({input_queue_.get()}); unpacker.Start(); std::vector test_frames = {CreateTestFrame(6), CreateTestFrame(7), CreateTestFrame(8)}; auto producer = input_queue_->CreateProducer(); producer.Put(MakeChunk(test_frames)); producer.Close(); std::vector actual_versions; actual_versions.reserve(test_frames.size()); for (size_t i = 0; i < test_frames.size(); ++i) { auto output_frame = unpacker.output_queue()->Get(); actual_versions.push_back(output_frame.version); EXPECT_EQ(output_frame.input_format, 3); EXPECT_EQ(output_frame.root_q, 0.5f); } std::vector expected_versions; expected_versions.reserve(test_frames.size()); for (const auto& frame : test_frames) { expected_versions.push_back(frame.version); } absl::c_sort(actual_versions); absl::c_sort(expected_versions); EXPECT_EQ(actual_versions, expected_versions); } TEST_F(ChunkUnpackerTest, UnpacksMultipleChunks) { ChunkUnpacker unpacker(config_); unpacker.SetInputs({input_queue_.get()}); unpacker.Start(); auto producer = input_queue_->CreateProducer(); // Send first chunk with 2 frames std::vector chunk1_frames = {CreateTestFrame(10), CreateTestFrame(11)}; producer.Put(MakeChunk(chunk1_frames, "source", 0)); // Send second chunk with 1 frame std::vector chunk2_frames = {CreateTestFrame(12)}; producer.Put(MakeChunk(chunk2_frames, "source", 1)); producer.Close(); // Verify all frames are output std::vector
// NOTE(review): the statement above continues below (mid-statement line
// split from extraction).
expected_versions = {10, 11, 12}; std::vector actual_versions; actual_versions.reserve(expected_versions.size()); for (size_t i = 0; i < expected_versions.size(); ++i) { auto output_frame = unpacker.output_queue()->Get(); actual_versions.push_back(output_frame.version); EXPECT_EQ(output_frame.input_format, 3); EXPECT_EQ(output_frame.root_q, 0.5f); } absl::c_sort(actual_versions); absl::c_sort(expected_versions); EXPECT_EQ(actual_versions, expected_versions); } TEST_F(ChunkUnpackerTest, HandlesEmptyChunk) { ChunkUnpacker unpacker(config_); unpacker.SetInputs({input_queue_.get()}); unpacker.Start(); auto producer = input_queue_->CreateProducer(); TrainingChunk empty_chunk; empty_chunk.sort_key = "source"; empty_chunk.index_within_sort_key = 0; producer.Put(std::move(empty_chunk)); producer.Close(); // Should not produce any output frames, queue should close EXPECT_THROW(unpacker.output_queue()->Get(), QueueClosedException); } TEST_F(ChunkUnpackerTest, HandlesQueueClosure) { ChunkUnpacker unpacker(config_); unpacker.SetInputs({input_queue_.get()}); unpacker.Start(); // Close input queue without sending data input_queue_->Close(); // Output queue should eventually close EXPECT_THROW(unpacker.output_queue()->Get(), QueueClosedException); } TEST(PickSampledPositionsTest, Deterministic) { absl::BitGen gen1(absl::SeedSeq{42}); std::vector result1 = PickSampledPositions(1000, 0.1, 5, gen1); absl::BitGen gen2(absl::SeedSeq{42}); std::vector result2 = PickSampledPositions(1000, 0.1, 5, gen2); EXPECT_EQ(result1, result2); } TEST(PickSampledPositionsTest, FullBucketFirstRound) { absl::BitGen gen(absl::SeedSeq{42}); const uint32_t n = 10000; const double p = 0.1; std::vector result = PickSampledPositions(n, p, 0, gen); // Expect size to be around n*p.
// NOTE(review): PickSampledPositions property tests — determinism, expected
// bucket size, disjointness between iterations, and carry-over of partial
// buckets. Two tests below carry comments saying they "will fail with the
// current implementation": they appear to pin intended (not yet implemented)
// behavior — confirm status against the repository's CI before relying on
// them.
EXPECT_NEAR(result.size(), n * p, n * p * 0.25); } TEST(PickSampledPositionsTest, DisjointBuckets) { absl::BitGen gen(absl::SeedSeq{42}); const uint32_t n = 1000; const double p = 0.1; std::vector bucket1 = PickSampledPositions(n, p, 0, gen); absl::c_sort(bucket1); // The generator state is now changed. For the next bucket, we need a fresh // one with the same seed to test the logic for a different iteration. absl::BitGen gen2(absl::SeedSeq{42}); std::vector bucket2 = PickSampledPositions(n, p, 1, gen2); absl::c_sort(bucket2); std::vector intersection; absl::c_set_intersection(bucket1, bucket2, std::back_inserter(intersection)); EXPECT_TRUE(intersection.empty()); } TEST(PickSampledPositionsTest, PartialBucketElementsAreReturned) { absl::BitGen gen(absl::SeedSeq{42}); const uint32_t n = 1000; const double p = 0.8; // remainder 0.2 // In round 1, elements with toss >= 0.8 are for iteration 1. absl::BitGen gen1(absl::SeedSeq{42}); std::vector expected_from_round1; for (uint32_t i = 0; i < n; ++i) { double toss = absl::Uniform(gen1, 0.0, 1.0); if (toss >= 0.8) { expected_from_round1.push_back(i); } } absl::c_sort(expected_from_round1); std::vector result = PickSampledPositions(n, p, 1, gen); absl::c_sort(result); // Check if all elements from round 1 are in the final result. // This will fail with the current implementation because they are discarded. std::vector intersection; absl::c_set_intersection(expected_from_round1, result, std::back_inserter(intersection)); EXPECT_GT(expected_from_round1.size(), 50); // High probability for n=1000 EXPECT_EQ(intersection.size(), expected_from_round1.size()); } TEST(PickSampledPositionsTest, PartialBucketCompletedSize) { absl::BitGen gen(absl::SeedSeq{42}); const uint32_t n = 10000; const double p = 0.8; // remainder 0.2 std::vector result = PickSampledPositions(n, p, 1, gen); // Expect size to be around n*p. // This will fail due to incorrect probability calculation for completion.
// NOTE(review): the line below finishes the last test, closes the test file,
// and begins file_path_provider.cc: an inotify-based directory watcher stage.
// ShouldSkipName/ShouldSkipPathEntry hide dot-prefixed entries; the
// constructor opens a non-blocking, close-on-exec inotify fd (CHECK-fails on
// error); the destructor stops the stage and closes the fd; SetInputs rejects
// any inputs (this is a source stage); Start launches the Worker thread.
EXPECT_NEAR(result.size(), n * p, n * p * 0.25); } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/file_path_provider.cc ================================================ #include "loader/stages/file_path_provider.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "loader/data_loader_metrics.h" #include "proto/data_loader_config.pb.h" namespace lczero { namespace training { namespace { bool ShouldSkipName(std::string_view name) { return !name.empty() && name.front() == '.'; } bool ShouldSkipPathEntry(const FilePathProvider::Path& path) { return ShouldSkipName(path.filename().string()); } } // namespace FilePathProvider::FilePathProvider(const FilePathProviderConfig& config) : SingleOutputStage(config.output()), directory_(config.directory()), producer_(output_queue()->CreateProducer()), load_metric_updater_() { LOG(INFO) << "Initializing FilePathProvider for directory: " << config.directory(); inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK); CHECK_NE(inotify_fd_, -1) << "Failed to initialize inotify: " << strerror(errno); } FilePathProvider::~FilePathProvider() { LOG(INFO) << "FilePathProvider shutting down."; Stop(); if (inotify_fd_ != -1) close(inotify_fd_); LOG(INFO) << "FilePathProvider shutdown complete."; } void FilePathProvider::SetInputs(absl::Span inputs) { if (!inputs.empty()) { throw std::runtime_error( "FilePathProvider expects no inputs, but received " + std::to_string(inputs.size())); } } void FilePathProvider::Start() { LOG(INFO) << "Starting FilePathProvider monitoring thread."; thread_pool_.Enqueue( [this](std::stop_token stop_token) { Worker(stop_token); }); } void FilePathProvider::Stop() { if (stop_source_.stop_requested()) return; LOG(INFO) << "Stopping FilePathProvider."; LOG(INFO) << "Stopping all watches..."; for (const auto& [wd, path] : watch_descriptors_) {
// NOTE(review): continuation of file_path_provider.cc (flattened lines;
// #include targets and template arguments stripped by extraction). This span
// finishes Stop() (removes all inotify watches, requests stop, shuts the
// thread pool, closes the producer), then FlushMetrics, AddDirectory (initial
// recursive scan followed by a kInitialScanComplete sentinel message), and
// the start of ScanDirectoryWithWatch. ScanDirectoryWithWatch deliberately
// installs the inotify watch BEFORE scanning, then reconciles events that
// raced with the scan — preserve that ordering.
inotify_rm_watch(inotify_fd_, wd); } watch_descriptors_.clear(); stop_source_.request_stop(); thread_pool_.Shutdown(); producer_.Close(); } StageMetricProto FilePathProvider::FlushMetrics() { StageMetricProto stage_metric; auto load_metrics = load_metric_updater_.FlushMetrics(); load_metrics.set_name("load"); *stage_metric.add_load_metrics() = std::move(load_metrics); *stage_metric.add_queue_metrics() = MetricsFromQueue("output", *output_queue()); return stage_metric; } void FilePathProvider::AddDirectory(const Path& directory, std::stop_token stop_token) { ScanDirectoryWithWatch(directory, stop_token); LOG(INFO) << "FilePathProvider registered " << directory << "; active watch descriptors: " << watch_descriptors_.size(); // Signal that initial scan is complete LOG(INFO) << "FilePathProvider initial scan complete"; producer_.Put( {{.filepath = Path{}, .message_type = MessageType::kInitialScanComplete}}, stop_token); } void FilePathProvider::ScanDirectoryWithWatch(const Path& directory, std::stop_token stop_token) { // Step 1: Set up watch first int wd = inotify_add_watch(inotify_fd_, directory.c_str(), IN_CLOSE_WRITE | IN_MOVED_TO | IN_CREATE | IN_DELETE | IN_DELETE_SELF | IN_MOVE); CHECK_NE(wd, -1) << "Failed to add inotify watch for " << directory << ": " << strerror(errno); watch_descriptors_[wd] = directory; // Step 2: Scan directory non-recursively, remembering files and subdirs std::vector files; std::vector subdirectories; std::error_code ec; auto iterator = std::filesystem::directory_iterator(directory, ec); CHECK(!ec) << "Failed to iterate directory " << directory << ": " << ec.message(); for (const auto& entry : iterator) { const Path entry_path = entry.path(); if (ShouldSkipPathEntry(entry_path)) continue; if (entry.is_regular_file(ec) && !ec) { files.push_back(entry_path); } else if (entry.is_directory(ec) && !ec) { subdirectories.push_back(entry_path); } } const size_t initial_file_count = files.size(); const size_t subdirectory_count =
// NOTE(review): continuation — batches discovered files (kBatchSize = 10000)
// into the producer, reconciles watch events that arrived during the scan via
// ProcessWatchEventsForNewItems, clears the file list to free memory, and
// recurses into subdirectories (honoring stop_token). NOTE(review): the final
// flush_batch() call appears AFTER the recursive subdirectory loop, so a
// partial batch is held across the recursion — looks intentional but verify
// ordering against the repository.
subdirectories.size(); LOG(INFO) << "FilePathProvider scanned " << directory << " discovering " << initial_file_count << " file(s) and " << subdirectory_count << " subdirectory(ies) before watch reconciliation."; // Send notifications for discovered files constexpr size_t kBatchSize = 10000; std::vector batch; batch.reserve(kBatchSize); auto flush_batch = [&]() { if (batch.empty()) return; producer_.Put(batch, stop_token); batch.clear(); }; for (const auto& filepath : files) { batch.push_back( {.filepath = filepath.string(), .message_type = MessageType::kFile}); if (batch.size() >= kBatchSize) flush_batch(); } if (initial_file_count > 0) { LOG(INFO) << "FilePathProvider enqueued " << initial_file_count << " file(s) from initial scan of " << directory; } // Step 3: Read from watch descriptor, skipping already discovered files ProcessWatchEventsForNewItems(files); // Step 4: Clean the files vector to save memory files.clear(); // Step 5: Recursively call for subdirectories for (const auto& subdir : subdirectories) { if (stop_token.stop_requested()) return; ScanDirectoryWithWatch(subdir, stop_token); } // Flush any remaining files flush_batch(); } void FilePathProvider::ProcessWatchEventsForNewItems( const std::vector& known_files) { // Create a set for fast lookup of already discovered files absl::flat_hash_set known_file_set; for (const auto& file : known_files) { known_file_set.insert(file.string()); } // Process any events that may have occurred during scanning std::array buffer; std::vector new_files; while (true) { ssize_t length = read(inotify_fd_, buffer.data(), buffer.size()); if (length <= 0) break; // No more events to process ssize_t offset = 0; while (offset < length) { const struct inotify_event* event = reinterpret_cast(buffer.data() + offset); const bool skip_entry = event->len > 0 && ShouldSkipName(event->name); // Only process file creation/write events, skip already known files if ((event->mask & (IN_CLOSE_WRITE | IN_MOVED_TO)) != 0 && event->len >
// NOTE(review): continuation — finishes the scan/watch race reconciliation
// (drains pending inotify events, forwarding only files not already known),
// then AddWatchRecursive (installs watches down a directory tree) and the
// start of RemoveWatchRecursive (erases watch descriptors whose path is base
// or a descendant of base, via absl::c_mismatch on path components).
0 && !skip_entry) { const Path directory(watch_descriptors_.at(event->wd)); Path filepath = directory / event->name; std::string filepath_string = filepath.string(); // Only add if we haven't seen this file before if (!known_file_set.contains(filepath_string)) { new_files.push_back({.filepath = std::move(filepath_string), .message_type = MessageType::kFile}); } } offset += sizeof(struct inotify_event) + event->len; } } // Send notifications for any new files discovered through watch events if (!new_files.empty()) { LOG(INFO) << "FilePathProvider observed " << new_files.size() << " new file(s) while reconciling race events."; producer_.Put(new_files); } } void FilePathProvider::AddWatchRecursive(const Path& path) { // Add watch for current directory int wd = inotify_add_watch(inotify_fd_, path.c_str(), IN_CLOSE_WRITE | IN_MOVED_TO | IN_CREATE | IN_DELETE | IN_DELETE_SELF | IN_MOVE); CHECK_NE(wd, -1) << "Failed to add inotify watch for " << path << ": " << strerror(errno); watch_descriptors_[wd] = path; // Recursively add watches for subdirectories std::error_code ec; auto iterator = std::filesystem::directory_iterator(path, ec); CHECK(!ec) << "Failed to iterate directory " << path << ": " << ec.message(); for (const auto& entry : iterator) { const Path entry_path = entry.path(); if (ShouldSkipPathEntry(entry_path)) continue; if (!entry.is_directory(ec) || ec) continue; AddWatchRecursive(entry_path); } } void FilePathProvider::RemoveWatchRecursive(const Path& base) { absl::erase_if(watch_descriptors_, [&](const auto& pair) { const auto& [wd, path] = pair; const auto mismatch_iter = absl::c_mismatch(base, path).first; // If path is not a subdirectory (or equal) of base, skip.
// NOTE(review): continuation — finishes RemoveWatchRecursive, then Worker:
// performs the initial AddDirectory scan on the background thread, registers
// the inotify fd with epoll (closed via absl::Cleanup), and polls every 50ms
// (sleep wrapped in LoadMetricPauser so the wait is not counted as load),
// draining all pending inotify events on each wakeup.
if (mismatch_iter != base.end()) return false; inotify_rm_watch(inotify_fd_, wd); return true; }); } void FilePathProvider::Worker(std::stop_token stop_token) { // Perform directory scanning in background thread AddDirectory(directory_, stop_token); int epoll_fd = epoll_create1(EPOLL_CLOEXEC); CHECK_NE(epoll_fd, -1) << "Failed to create epoll fd: " << strerror(errno); absl::Cleanup epoll_cleanup([epoll_fd]() { close(epoll_fd); }); struct epoll_event event; event.events = EPOLLIN; event.data.fd = inotify_fd_; CHECK_EQ(epoll_ctl(epoll_fd, EPOLL_CTL_ADD, inotify_fd_, &event), 0) << "Failed to add inotify fd to epoll: " << strerror(errno); while (!stop_token.stop_requested()) { { LoadMetricPauser pauser(load_metric_updater_); std::this_thread::sleep_for(std::chrono::milliseconds(50)); if (stop_token.stop_requested()) { pauser.DoNotResume(); break; } } struct epoll_event event; int nfds = epoll_wait(epoll_fd, &event, 1, 0); // Non-blocking check CHECK_NE(nfds, -1) << "epoll_wait failed: " << strerror(errno); if (nfds == 0) continue; // No events.
// NOTE(review): continuation — finishes Worker's drain loop, then
// ProcessInotifyEvents (reads raw event buffers, batching up to 10000 file
// notifications per Put) and ProcessInotifyEvent, which classifies a single
// event: file written/moved-in -> emit kFile; directory created -> scan+watch
// it; directory deleted (or IN_DELETE_SELF) -> drop its watches recursively;
// IN_IGNORED -> no-op.
do { assert(nfds == 1 && event.data.fd == inotify_fd_); ProcessInotifyEvents(producer_, stop_token); nfds = epoll_wait(epoll_fd, &event, 1, 0); } while (nfds > 0); } } void FilePathProvider::ProcessInotifyEvents(Queue::Producer& producer, std::stop_token stop_token) { constexpr size_t kNotifyBatchSize = 10000; std::vector files; std::array buffer; auto flush_batch = [&]() { if (files.empty()) return; producer.Put(files, stop_token); files.clear(); }; while (true) { ssize_t length = read(inotify_fd_, buffer.data(), buffer.size()); if (length <= 0) break; // No more events to process ssize_t offset = 0; while (offset < length) { const struct inotify_event* event = reinterpret_cast(buffer.data() + offset); auto file = ProcessInotifyEvent(*event, stop_token); if (file) files.push_back(*file); if (files.size() >= kNotifyBatchSize) flush_batch(); offset += sizeof(struct inotify_event) + event->len; } } flush_batch(); // Flush any remaining files in the batch } auto FilePathProvider::ProcessInotifyEvent(const struct inotify_event& event, std::stop_token stop_token) -> std::optional { if (event.mask & IN_IGNORED) return std::nullopt; const Path directory(watch_descriptors_.at(event.wd)); const bool has_name = event.len > 0 && event.name[0] != '\0'; const bool skip_entry = has_name && ShouldSkipName(event.name); const Path filepath = has_name ?
// NOTE(review): the line below completes ProcessInotifyEvent, closes
// file_path_provider.cc, and begins file_path_provider.h (pragma, includes
// with stripped <...> targets, and the FilePathProviderMessageType enum:
// kFile for discovered files, kInitialScanComplete as the scan sentinel).
directory / event.name : directory; // Handle different event types if ((event.mask & (IN_CLOSE_WRITE | IN_MOVED_TO)) != 0 && has_name && !skip_entry) { // File finished writing or moved into directory return File{.filepath = filepath, .message_type = MessageType::kFile}; } constexpr uint32_t kDirCreateMask = IN_CREATE | IN_ISDIR; constexpr uint32_t kDirDeleteMask = IN_DELETE | IN_ISDIR; if ((event.mask & kDirCreateMask) == kDirCreateMask) { if (!has_name || skip_entry) return std::nullopt; ScanDirectoryWithWatch(filepath, stop_token); } else if ((event.mask & kDirDeleteMask) == kDirDeleteMask) { if (!has_name || skip_entry) return std::nullopt; // Directory deleted - remove all watches for it and subdirectories RemoveWatchRecursive(filepath); } else if (event.mask & IN_DELETE_SELF) { RemoveWatchRecursive(directory); } return std::nullopt; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/file_path_provider.h ================================================ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include "loader/stages/stage.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" #include "utils/metrics/load_metric.h" #include "utils/metrics/printer.h" #include "utils/metrics/statistics_metric.h" #include "utils/queue.h" #include "utils/thread_pool.h" namespace lczero { namespace training { // Message types for FilePathProvider output. enum class FilePathProviderMessageType { kFile, // File discovered (initial scan or inotify) kInitialScanComplete // Initial scan is complete (empty filepath) }; // Output type for FilePathProvider.
// NOTE(review): continuation of file_path_provider.h (flattened lines;
// template arguments stripped by extraction). Declares the
// FilePathProviderFile message struct and the FilePathProvider stage class:
// a recursive inotify directory watcher with a single background Worker
// thread, producing File messages on its output queue.
struct FilePathProviderFile { std::filesystem::path filepath; FilePathProviderMessageType message_type; }; // This class watches for new files in a directory (recursively) and notifies // registered observers when new files are either closed after writing or // renamed into. // Uses background thread to monitor the directory. class FilePathProvider : public SingleOutputStage { public: using Path = std::filesystem::path; using MessageType = FilePathProviderMessageType; using File = FilePathProviderFile; explicit FilePathProvider(const FilePathProviderConfig& config); ~FilePathProvider(); // Starts monitoring the directory void Start() override; // Closes the output queue, signaling completion void Stop() override; // Returns current metrics and clears them. StageMetricProto FlushMetrics() override; // FilePathProvider has no inputs. void SetInputs(absl::Span inputs) override; private: // Starts monitoring the directory. void AddDirectory(const Path& directory, std::stop_token stop_token); void Worker(std::stop_token stop_token); void AddWatchRecursive(const Path& path); void RemoveWatchRecursive(const Path& path); void ScanDirectoryWithWatch(const Path& directory, std::stop_token stop_token); void ProcessWatchEventsForNewItems(const std::vector& known_files); void ProcessInotifyEvents(Queue::Producer& producer, std::stop_token stop_token); std::optional ProcessInotifyEvent(const struct inotify_event& event, std::stop_token stop_token); int inotify_fd_; // Watch descriptor to directory path.
// NOTE(review): the two lines below must stay adjacent — a string literal in
// main() ("Monitoring for files... Press Enter to exit.") is split across
// their boundary by the extraction. They finish the FilePathProvider class
// (watch map, monitored directory, producer, load metrics, stop source, and
// a 1-thread pool), then hold file_path_provider_main.cc: a small CLI that
// watches --directory and logs every queue message from a consumer thread
// until Enter is pressed, and the head of file_path_provider_test.cc
// (MakeConfig/RelativeTo helpers and the fixture's temp-dir SetUp/TearDown,
// CreateFile/CreateDirectory, and DrainInitialScan, which collects kFile
// messages until the kInitialScanComplete sentinel).
absl::flat_hash_map watch_descriptors_; Path directory_; // Directory to monitor Queue::Producer producer_; LoadMetricUpdater load_metric_updater_; std::stop_source stop_source_; ThreadPool thread_pool_{1, ThreadPoolOptions{}, stop_source_}; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/file_path_provider_main.cc ================================================ #include #include #include #include #include #include #include #include #include "loader/stages/file_path_provider.h" ABSL_FLAG(std::string, directory, "", "Directory to monitor for files"); int main(int argc, char* argv[]) { absl::ParseCommandLine(argc, argv); absl::InitializeLog(); absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo); std::string directory = absl::GetFlag(FLAGS_directory); if (directory.empty()) { std::cerr << "Usage: " << argv[0] << " --directory=" << std::endl; return 1; } LOG(INFO) << "Starting to monitor directory: " << directory; lczero::training::FilePathProviderConfig config; config.mutable_output()->set_queue_capacity(16); config.set_directory(directory); lczero::training::FilePathProvider file_path_provider(config); // Consumer thread to read from the queue std::thread consumer_thread([&file_path_provider]() { auto* queue = file_path_provider.output_queue(); try { while (true) { auto file = queue->Get(); const char* type_str = (file.message_type == lczero::training::FilePathProvider::MessageType::kFile) ? "File" : "Initial scan complete"; LOG(INFO) << "File " << type_str << ": " << file.filepath; } } catch (const lczero::QueueClosedException&) { LOG(INFO) << "Queue closed, consumer thread exiting"; } }); LOG(INFO) << "Monitoring for files...
Press Enter to exit."; std::cin.get(); // Close the queue and wait for consumer to finish file_path_provider.Stop(); consumer_thread.join(); return 0; } ================================================ FILE: csrc/loader/stages/file_path_provider_test.cc ================================================ #include "loader/stages/file_path_provider.h" #include #include #include #include #include #include #include namespace lczero { namespace training { namespace { FilePathProviderConfig MakeConfig(const std::filesystem::path& directory) { FilePathProviderConfig config; config.mutable_output()->set_queue_capacity(128); config.set_directory(directory.string()); return config; } std::string RelativeTo(const std::filesystem::path& base, const std::filesystem::path& target) { return target.lexically_relative(base).generic_string(); } } // namespace class FilePathProviderTest : public ::testing::Test { protected: void SetUp() override { test_dir_ = std::filesystem::temp_directory_path() / ("file_path_provider_test_" + std::to_string( std::chrono::steady_clock::now().time_since_epoch().count())); std::filesystem::create_directories(test_dir_); } void TearDown() override { if (std::filesystem::exists(test_dir_)) { std::filesystem::remove_all(test_dir_); } } void CreateFile(const std::filesystem::path& path, const std::string& content = "payload") { std::filesystem::create_directories(path.parent_path()); std::ofstream file(path); file << content; } void CreateDirectory(const std::filesystem::path& path) { std::filesystem::create_directories(path); } std::vector DrainInitialScan( Queue* queue) { std::vector files; while (true) { auto message = queue->Get(); if (message.message_type == FilePathProvider::MessageType::kInitialScanComplete) { EXPECT_TRUE(message.filepath.empty()); break; } if (message.message_type != FilePathProvider::MessageType::kFile) { ADD_FAILURE() << "Unexpected message type in initial scan."; continue; } files.push_back(message.filepath); } return files; }
// NOTE(review): continuation of file_path_provider_test.cc (flattened lines;
// template arguments stripped by extraction). AwaitNextFile blocks until a
// kFile message arrives (tolerating stray kInitialScanComplete). The tests
// cover: queue capacity, initial scan of visible files (recursive), skipping
// of dot-prefixed ("hidden") files and directories, detection of new files
// and files in pre-existing subdirectories, and empty directories.
FilePathProvider::File AwaitNextFile(Queue* queue) { while (true) { auto message = queue->Get(); if (message.message_type == FilePathProvider::MessageType::kFile) { return message; } if (message.message_type != FilePathProvider::MessageType::kInitialScanComplete) { ADD_FAILURE() << "Unexpected message type while waiting for file notification."; } } } FilePathProviderConfig Config() const { return MakeConfig(test_dir_); } std::filesystem::path test_dir_; }; TEST_F(FilePathProviderTest, ConstructorCreatesQueue) { FilePathProvider provider(Config()); provider.Start(); auto* queue = provider.output_queue(); ASSERT_NE(queue, nullptr); EXPECT_EQ(queue->Capacity(), 128); auto message = queue->Get(); EXPECT_EQ(message.message_type, FilePathProvider::MessageType::kInitialScanComplete); EXPECT_TRUE(message.filepath.empty()); provider.Stop(); } TEST_F(FilePathProviderTest, InitialScanFindsVisibleFiles) { CreateFile(test_dir_ / "file1.txt"); CreateFile(test_dir_ / "file2.txt"); CreateFile(test_dir_ / "sub" / "nested.txt"); FilePathProvider provider(Config()); provider.Start(); auto* queue = provider.output_queue(); auto discovered = DrainInitialScan(queue); std::unordered_set relative_paths; for (const auto& path : discovered) { relative_paths.insert(RelativeTo(test_dir_, path)); } EXPECT_EQ(relative_paths.size(), 3u); EXPECT_TRUE(relative_paths.count("file1.txt")); EXPECT_TRUE(relative_paths.count("file2.txt")); EXPECT_TRUE(relative_paths.count("sub/nested.txt")); provider.Stop(); } TEST_F(FilePathProviderTest, InitialScanSkipsHiddenEntries) { CreateFile(test_dir_ / "visible.txt"); CreateFile(test_dir_ / ".hidden_file"); CreateFile(test_dir_ / ".hidden_dir" / "nested.txt"); CreateFile(test_dir_ / "visible_dir" / "child.txt"); FilePathProvider provider(Config()); provider.Start(); auto* queue = provider.output_queue(); auto discovered = DrainInitialScan(queue); std::unordered_set relative_paths; for (const auto& path : discovered) { relative_paths.insert(RelativeTo(test_dir_,
// NOTE(review): continuation (the statement above resumes below —
// mid-expression line split from extraction).
path)); } EXPECT_TRUE(relative_paths.count("visible.txt")); EXPECT_TRUE(relative_paths.count("visible_dir/child.txt")); EXPECT_FALSE(relative_paths.count(".hidden_file")); EXPECT_FALSE(relative_paths.count(".hidden_dir/nested.txt")); provider.Stop(); } TEST_F(FilePathProviderTest, DetectsNewVisibleFile) { FilePathProvider provider(Config()); provider.Start(); auto* queue = provider.output_queue(); DrainInitialScan(queue); CreateFile(test_dir_ / "new_file.txt"); auto message = AwaitNextFile(queue); EXPECT_EQ(RelativeTo(test_dir_, message.filepath), "new_file.txt"); provider.Stop(); } TEST_F(FilePathProviderTest, DetectsFilesInPreExistingSubdirectory) { auto subdir = test_dir_ / "subdir"; CreateDirectory(subdir); FilePathProvider provider(Config()); provider.Start(); auto* queue = provider.output_queue(); DrainInitialScan(queue); CreateFile(subdir / "from_subdir.txt"); auto message = AwaitNextFile(queue); EXPECT_EQ(RelativeTo(test_dir_, message.filepath), "subdir/from_subdir.txt"); provider.Stop(); } TEST_F(FilePathProviderTest, IgnoresHiddenFileEvents) { FilePathProvider provider(Config()); provider.Start(); auto* queue = provider.output_queue(); DrainInitialScan(queue); CreateFile(test_dir_ / ".hidden_event.txt"); CreateFile(test_dir_ / "visible_after_hidden.txt"); auto message = AwaitNextFile(queue); EXPECT_EQ(RelativeTo(test_dir_, message.filepath), "visible_after_hidden.txt"); provider.Stop(); } TEST_F(FilePathProviderTest, SkipsHiddenDirectoryRecursion) { FilePathProvider provider(Config()); provider.Start(); auto* queue = provider.output_queue(); DrainInitialScan(queue); CreateDirectory(test_dir_ / ".hidden_dir"); CreateFile(test_dir_ / ".hidden_dir" / "inner.txt"); CreateFile(test_dir_ / "outer.txt"); auto message = AwaitNextFile(queue); EXPECT_EQ(RelativeTo(test_dir_, message.filepath), "outer.txt"); provider.Stop(); } TEST_F(FilePathProviderTest, HandlesEmptyDirectory) { FilePathProvider provider(Config()); provider.Start(); auto* queue =
// NOTE(review): the line below finishes the last test, closes the test file,
// and holds most of join_stage.cc: JoinStage<T> forwards items from N input
// queues to one output, one worker thread per input. SetInputs dynamic_casts
// each QueueBase* to Queue<T>* (throws on mismatch); Start allocates one
// ThreadContext + worker per input; Worker wraps the blocking Get in a
// LoadMetricPauser and forwards items; Stop shuts the pool then closes the
// output queue. The bare trailing "template" token belongs to the
// FlushMetrics definition continuing on the next line.
provider.output_queue(); auto discovered = DrainInitialScan(queue); EXPECT_TRUE(discovered.empty()); provider.Stop(); } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/join_stage.cc ================================================ #include "loader/stages/join_stage.h" #include #include "loader/data_loader_metrics.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" #include "utils/metrics/load_metric.h" namespace lczero { namespace training { template JoinStage::JoinStage(const JoinPositionsConfig& config) : SingleOutputStage(config.output()) {} template JoinStage::~JoinStage() { Stop(); } template void JoinStage::SetInputs(absl::Span inputs) { input_queues_.clear(); for (QueueBase* base_queue : inputs) { auto* typed_queue = dynamic_cast*>(base_queue); if (!typed_queue) throw std::runtime_error("Input queue type mismatch"); input_queues_.push_back(typed_queue); } } template void JoinStage::Start() { thread_contexts_.clear(); thread_pool_ = std::make_unique(input_queues_.size()); for (size_t i = 0; i < input_queues_.size(); ++i) { thread_contexts_.push_back(std::make_unique()); } for (size_t i = 0; i < input_queues_.size(); ++i) { thread_pool_->Enqueue([this, i](std::stop_token stop_token) { Worker(stop_token, input_queues_[i], thread_contexts_[i].get()); }); } } template void JoinStage::Worker(std::stop_token stop_token, Queue* input_queue, ThreadContext* context) { auto producer = this->output_queue()->CreateProducer(); try { while (true) { auto item = [&]() { LoadMetricPauser pauser(context->load_metric_updater); return input_queue->Get(stop_token); }(); producer.Put(std::move(item), stop_token); } } catch (const QueueClosedException&) { } } template void JoinStage::Stop() { if (!thread_pool_ || thread_pool_->stop_token().stop_requested()) return; LOG(INFO) << "Stopping JoinStage."; thread_pool_->Shutdown(); this->output_queue()->Close(); } template
// NOTE(review): the line below finishes join_stage.cc (FlushMetrics
// aggregates per-thread load metrics plus output-queue metrics; the stage is
// explicitly instantiated for FrameType) and begins join_stage.h. The
// JoinStage class declaration at the very end runs past the visible chunk
// and is incomplete here — do not edit it from this view.
StageMetricProto JoinStage::FlushMetrics() { StageMetricProto metrics; LoadMetricProto aggregated_load; aggregated_load.set_name("load"); for (const auto& context : thread_contexts_) { UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics()); } *metrics.add_load_metrics() = std::move(aggregated_load); *metrics.add_queue_metrics() = MetricsFromQueue("output", *this->output_queue()); return metrics; } // Explicit template instantiation for FrameType. template class JoinStage; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/join_stage.h ================================================ // ABOUTME: Stage that joins multiple input queues into a single output. // ABOUTME: Spawns one thread per input to read and forward items. #pragma once #include #include #include #include #include "loader/data_loader_metrics.h" #include "loader/frame_type.h" #include "loader/stages/stage.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" #include "utils/queue.h" #include "utils/thread_pool.h" namespace lczero { namespace training { // Template stage that joins multiple input queues into a single output. // Spawns one thread per input to consume and forward items. template class JoinStage : public SingleOutputStage { public: using OutputType = T; explicit JoinStage(const JoinPositionsConfig& config); ~JoinStage(); void Start() override; void Stop() override; StageMetricProto FlushMetrics() override; void SetInputs(absl::Span inputs) override; private: struct ThreadContext { LoadMetricUpdater load_metric_updater; }; void Worker(std::stop_token stop_token, Queue* input_queue, ThreadContext* context); std::vector*> input_queues_; // thread_contexts_ must be declared before thread_pool_ to ensure // thread_pool_ is destroyed first (stopping threads before contexts).
std::vector> thread_contexts_; std::unique_ptr thread_pool_; }; using JoinPositions = JoinStage; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/join_stage_test.cc ================================================ #include "loader/stages/join_stage.h" #include #include #include "absl/container/flat_hash_set.h" #include "gtest/gtest.h" #include "libs/lc0/src/trainingdata/trainingdata_v6.h" #include "proto/data_loader_config.pb.h" #include "utils/queue.h" namespace lczero { namespace training { class JoinStageTest : public ::testing::Test { protected: void SetUp() override { config_.mutable_output()->set_queue_capacity(100); } FrameType CreateTestFrame(uint32_t version) { FrameType frame{}; frame.version = version; frame.input_format = 3; frame.root_q = 0.5f; return frame; } JoinPositionsConfig config_; }; TEST_F(JoinStageTest, JoinsTwoInputs) { auto input_queue_1 = std::make_unique>(10); auto input_queue_2 = std::make_unique>(10); JoinPositions join_stage(config_); join_stage.SetInputs({input_queue_1.get(), input_queue_2.get()}); join_stage.Start(); auto producer_1 = input_queue_1->CreateProducer(); auto producer_2 = input_queue_2->CreateProducer(); producer_1.Put(CreateTestFrame(1)); producer_1.Put(CreateTestFrame(2)); producer_2.Put(CreateTestFrame(3)); producer_2.Put(CreateTestFrame(4)); absl::flat_hash_set received_versions; for (int i = 0; i < 4; ++i) { auto frame = join_stage.output_queue()->Get(); received_versions.insert(frame.version); } producer_1.Close(); producer_2.Close(); join_stage.Stop(); EXPECT_EQ(received_versions.size(), 4u); EXPECT_TRUE(received_versions.contains(1)); EXPECT_TRUE(received_versions.contains(2)); EXPECT_TRUE(received_versions.contains(3)); EXPECT_TRUE(received_versions.contains(4)); } TEST_F(JoinStageTest, JoinsThreeInputs) { auto input_queue_1 = std::make_unique>(10); auto input_queue_2 = std::make_unique>(10); auto input_queue_3 = std::make_unique>(10); 
JoinPositions join_stage(config_); join_stage.SetInputs( {input_queue_1.get(), input_queue_2.get(), input_queue_3.get()}); join_stage.Start(); auto producer_1 = input_queue_1->CreateProducer(); auto producer_2 = input_queue_2->CreateProducer(); auto producer_3 = input_queue_3->CreateProducer(); producer_1.Put(CreateTestFrame(10)); producer_2.Put(CreateTestFrame(20)); producer_3.Put(CreateTestFrame(30)); absl::flat_hash_set received_versions; for (int i = 0; i < 3; ++i) { auto frame = join_stage.output_queue()->Get(); received_versions.insert(frame.version); } producer_1.Close(); producer_2.Close(); producer_3.Close(); join_stage.Stop(); EXPECT_EQ(received_versions.size(), 3u); EXPECT_TRUE(received_versions.contains(10)); EXPECT_TRUE(received_versions.contains(20)); EXPECT_TRUE(received_versions.contains(30)); } TEST_F(JoinStageTest, HandlesEmptyInputs) { auto input_queue_1 = std::make_unique>(10); auto input_queue_2 = std::make_unique>(10); JoinPositions join_stage(config_); join_stage.SetInputs({input_queue_1.get(), input_queue_2.get()}); join_stage.Start(); auto producer_1 = input_queue_1->CreateProducer(); auto producer_2 = input_queue_2->CreateProducer(); producer_1.Close(); producer_2.Close(); auto maybe_frame = join_stage.output_queue()->MaybeGet(); EXPECT_FALSE(maybe_frame.has_value()); join_stage.Stop(); } TEST_F(JoinStageTest, FlushesMetrics) { auto input_queue = std::make_unique>(10); JoinPositions join_stage(config_); join_stage.SetInputs({input_queue.get()}); join_stage.Start(); auto producer = input_queue->CreateProducer(); producer.Put(CreateTestFrame(1)); auto frame = join_stage.output_queue()->Get(); EXPECT_EQ(frame.version, 1u); producer.Close(); join_stage.Stop(); auto metrics = join_stage.FlushMetrics(); EXPECT_EQ(metrics.load_metrics_size(), 1); EXPECT_EQ(metrics.queue_metrics_size(), 1); } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/position_sampling.cc 
================================================ #include "loader/stages/position_sampling.h" #include namespace lczero { namespace training { float ComputePositionSamplingWeight(const FrameType& frame, const PositionSamplingConfig& config) { if (!config.has_diff_focus_q_weight() && !config.has_diff_focus_pol_scale()) { return config.default_weight(); } if (std::isnan(frame.orig_q)) return config.default_weight(); const float diff_q = std::abs(frame.best_q - frame.orig_q); const float q_weight = config.diff_focus_q_weight(); const float pol_scale = config.diff_focus_pol_scale(); const float total = (q_weight * diff_q + frame.policy_kld) / (q_weight + pol_scale); return std::min( std::pow(total * config.diff_focus_alpha() + config.diff_focus_beta(), config.diff_focus_gamma()), config.diff_focus_tau()); } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/position_sampling.h ================================================ #pragma once #include "loader/frame_type.h" #include "proto/data_loader_config.pb.h" namespace lczero { namespace training { float ComputePositionSamplingWeight(const FrameType& frame, const PositionSamplingConfig& config); } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/shuffling_chunk_pool.cc ================================================ #include "loader/stages/shuffling_chunk_pool.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "loader/chunk_source/chunk_source.h" #include "loader/data_loader_metrics.h" #include "loader/stages/chunk_source_loader.h" #include "loader/stages/position_sampling.h" #include "proto/data_loader_config.pb.h" #include "utils/thread_pool.h" namespace lczero { namespace training { thread_local absl::BitGen ShufflingChunkPool::bitgen_{absl::MakeSeedSeq()}; 
// Constructor: wires up the primary output queue, optional cache-hit output
// queue, and the three worker pools (source ingestion, chunk loading,
// caching). The caching pool gets zero threads when cachehit_output is not
// configured. Throws if the two output names collide.
ShufflingChunkPool::ShufflingChunkPool(const ShufflingChunkPoolConfig& config)
    : primary_output_name_(config.output().name()),
      primary_output_queue_(
          config.output().queue_capacity(),
          ToOverflowBehavior(config.output().overflow_behavior())),
      chunk_pool_size_(config.chunk_pool_size()),
      config_(config),
      source_ingestion_pool_(config.source_ingestion_threads(),
                             ThreadPoolOptions{}, stop_source_),
      chunk_loading_pool_(config.chunk_loading_threads(), ThreadPoolOptions{},
                          stop_source_),
      caching_pool_(config.has_cachehit_output() ? config.caching_threads() : 0,
                    ThreadPoolOptions{}, stop_source_) {
  if (config.has_cachehit_output()) {
    cachehit_output_name_ = config.cachehit_output().name();
    cachehit_output_queue_.emplace(
        config.cachehit_output().queue_capacity(),
        ToOverflowBehavior(config.cachehit_output().overflow_behavior()));
    // The two outputs are looked up by name in GetOutput(); equal names
    // would make the cache-hit output unreachable.
    if (primary_output_name_ == *cachehit_output_name_) {
      throw std::runtime_error(absl::StrCat(
          "ShufflingChunkPool output names must be different, got: '",
          primary_output_name_, "'"));
    }
  }
  LOG(INFO) << "Initializing ShufflingChunkPool with pool size "
            << config.chunk_pool_size();
}

ShufflingChunkPool::~ShufflingChunkPool() { Stop(); }

// Validates and stores the input queues. Exactly one input (chunk sources)
// is expected without cachehit_output; exactly two (chunk sources + cache
// requests) with it.
// NOTE(review): the element types of the dynamic_casts below were stripped
// by extraction (original reads "dynamic_cast*>"); the concrete Queue<...>
// template arguments must be recovered from upstream / chunk_source_loader.h.
void ShufflingChunkPool::SetInputs(absl::Span inputs) {
  if (inputs.size() != 1 && inputs.size() != 2) {
    throw std::runtime_error(absl::StrCat(
        "ShufflingChunkPool expects 1 or 2 inputs, got ", inputs.size()));
  }
  if (inputs.size() == 2 && !cachehit_output_queue_.has_value()) {
    throw std::runtime_error(
        "ShufflingChunkPool received 2 inputs but cachehit_output is not "
        "configured");
  }
  if (inputs.size() == 1 && cachehit_output_queue_.has_value()) {
    throw std::runtime_error(
        "ShufflingChunkPool has cachehit_output configured but received only "
        "1 input");
  }
  primary_input_queue_ = dynamic_cast*>(inputs[0]);
  if (!primary_input_queue_) {
    throw std::runtime_error("ShufflingChunkPool primary input type mismatch");
  }
  if (inputs.size() == 2) {
    cache_request_queue_ = dynamic_cast*>(inputs[1]);
    if (!cache_request_queue_) {
      throw std::runtime_error(
          "ShufflingChunkPool cache request input type mismatch");
    }
  }
}

// Resolves an output queue by name; throws with the list of valid names on a
// miss.
QueueBase* ShufflingChunkPool::GetOutput(std::string_view name) {
  if (name == primary_output_name_) return &primary_output_queue_;
  if (cachehit_output_name_.has_value() && name == *cachehit_output_name_) {
    return &*cachehit_output_queue_;
  }
  std::string available = absl::StrCat("'", primary_output_name_, "'");
  if (cachehit_output_name_.has_value()) {
    absl::StrAppend(&available, ", '", *cachehit_output_name_, "'");
  }
  throw std::runtime_error(absl::StrCat("ShufflingChunkPool unknown output '",
                                        name,
                                        "'. Available outputs: ", available));
}

// Kicks off asynchronous initialization: a dedicated jthread performs the
// (potentially long) initial scan and window construction, then launches the
// ingestion, loading, and caching workers. On failure or early queue close
// the primary output is closed so downstream stages unblock.
// NOTE(review): the element type of the vector below was stripped by
// extraction; per InitializeChunkSources() it holds owned ChunkSource
// pointers -- confirm the exact declaration upstream.
void ShufflingChunkPool::Start() {
  LOG(INFO) << "Starting ShufflingChunkPool initialization thread.";
  initialization_thread_ = std::jthread([this]() {
    try {
      LOG(INFO) << "Starting ShufflingChunkPool with pool size "
                << config_.chunk_pool_size();
      std::vector> uninitialized_sources = InitializeChunkSources();
      ProcessInputFiles(std::move(uninitialized_sources));
      // Start input processing worker that continuously processes new files.
      for (size_t i = 0; i < source_ingestion_pool_.num_threads(); ++i) {
        auto* context = source_ingestion_thread_contexts_
                            .emplace_back(std::make_unique())
                            .get();
        source_ingestion_pool_.Enqueue(
            [this, context](std::stop_token stop_token) {
              SourceIngestionWorker(stop_token, context);
            });
      }
      // Start output workers after everything is fully initialized.
      LOG(INFO) << "ShufflingChunkPool initialization done, starting workers";
      for (size_t i = 0; i < chunk_loading_pool_.num_threads(); ++i) {
        auto* context = chunk_loading_thread_contexts_
                            .emplace_back(std::make_unique())
                            .get();
        chunk_loading_pool_.Enqueue(
            [this, context](std::stop_token stop_token) {
              OutputWorker(stop_token, context);
            });
      }
      // Start caching workers if configured.
if (cachehit_output_queue_.has_value()) { for (size_t i = 0; i < caching_pool_.num_threads(); ++i) { auto* context = caching_thread_contexts_ .emplace_back(std::make_unique()) .get(); caching_pool_.Enqueue([this, context](std::stop_token stop_token) { CachingWorker(stop_token, context); }); } } } catch (const QueueClosedException&) { LOG(INFO) << "ShufflingChunkPool initialization interrupted, input " "queue closed."; output_queue()->Close(); } catch (const std::exception& e) { LOG(ERROR) << "ShufflingChunkPool initialization failed: " << e.what(); output_queue()->Close(); } }); } void ShufflingChunkPool::Stop() { if (stop_source_.stop_requested()) return; LOG(INFO) << "Stopping ShufflingChunkPool."; stop_source_.request_stop(); if (initialization_thread_.joinable()) { initialization_thread_.request_stop(); initialization_thread_.join(); } source_ingestion_pool_.Shutdown(); chunk_loading_pool_.Shutdown(); if (cachehit_output_queue_) caching_pool_.Shutdown(); output_queue()->Close(); if (cachehit_output_queue_) cachehit_output_queue_->Close(); LOG(INFO) << "ShufflingChunkPool stopped."; } std::vector> ShufflingChunkPool::InitializeChunkSources() { std::vector> uninitialized_sources; // Read from input queue until kInitialScanComplete. while (true) { auto chunk_source_with_phase = input_queue()->Get(); if (chunk_source_with_phase.message_type == FilePathProvider::MessageType::kInitialScanComplete) { LOG(INFO) << "ShufflingChunkPool received initial scan completion marker."; break; } if (chunk_source_with_phase.message_type == FilePathProvider::MessageType::kFile) { // Add ChunkSource to uninitialized sources. uninitialized_sources.push_back( std::move(chunk_source_with_phase.source)); } } LOG(INFO) << "ShufflingChunkPool initial directory walk produced " << uninitialized_sources.size() << " chunk source candidate(s)."; // Sort in descending order (newest first). 
std::sort(uninitialized_sources.begin(), uninitialized_sources.end(), [](const auto& a, const auto& b) { return a->GetChunkSortKey() > b->GetChunkSortKey(); }); std::atomic total_chunks = 0; size_t sources_to_keep = 0; // Process sources sequentially until we have enough chunks. std::string current_anchor; { absl::MutexLock lock(&anchor_mutex_); current_anchor = anchor_; } for (auto& source : uninitialized_sources) { if (output_queue()->IsClosed()) { LOG(INFO) << "Output queue closed, stopping source ingestion."; break; } if (total_chunks >= chunk_pool_size_) break; // Count chunks immediately; constructors have already prepared metadata. const size_t chunk_count = source->GetChunkCount(); total_chunks += chunk_count; // Count chunks since anchor during initial load. if (source->GetChunkSortKey() > current_anchor) { chunks_since_anchor_ += chunk_count; } LOG_EVERY_N_SEC(INFO, 4) << "Loaded so far: " << total_chunks.load() << "; new: " << chunks_since_anchor_; ++sources_to_keep; } LOG(INFO) << "ShufflingChunkPool indexed " << total_chunks.load() << " chunk(s) across " << sources_to_keep << " source(s) during startup."; if (total_chunks < chunk_pool_size_ && !output_queue()->IsClosed()) { LOG(ERROR) << "ShufflingChunkPool startup chunk requirement not met: " << total_chunks.load() << " < " << chunk_pool_size_; } // Trim the vector to only keep the sources we need. uninitialized_sources.resize(sources_to_keep); return uninitialized_sources; } void ShufflingChunkPool::ProcessInputFiles( std::vector> uninitialized_sources) { // Initialize chunk sources from the initial scan. size_t initial_window_sources = 0; size_t initial_total_chunks = 0; { absl::MutexLock lock(&chunk_sources_mutex_); size_t start_chunk_index = 0; // Newest sources first, so we add in reverse order. 
std::for_each(uninitialized_sources.rbegin(), uninitialized_sources.rend(), [this, &start_chunk_index](auto& source) { const size_t count = source->GetChunkCount(); auto item = std::make_shared(); item->start_chunk_index = start_chunk_index; item->source = std::move(source); item->use_counts = std::vector(count, 0); item->weight = std::vector(count, -1.0f); item->cache = std::vector>( cachehit_output_queue_.has_value() ? count : 0); chunk_sources_.push_back(std::move(item)); start_chunk_index += chunk_sources_.back()->source->GetChunkCount(); }); // Initialize stream shuffler with the initial bounds. if (!chunk_sources_.empty()) { size_t total_chunks = chunk_sources_.back()->start_chunk_index + chunk_sources_.back()->source->GetChunkCount(); // Set bounds to provide the last chunk_pool_size_ chunks. size_t lower_bound = total_chunks > chunk_pool_size_ ? total_chunks - chunk_pool_size_ : 0; stream_shuffler_.SetLowerBound(lower_bound); stream_shuffler_.SetUpperBound(total_chunks); initial_total_chunks = total_chunks; } initial_window_sources = chunk_sources_.size(); } LOG(INFO) << "ShufflingChunkPool initial window ready with " << initial_window_sources << " source(s) totaling " << initial_total_chunks << " chunk(s)."; // Log anchor and sources after initial scan completion. 
{ absl::MutexLock anchor_lock(&anchor_mutex_); LOG(INFO) << "Current anchor: '" << anchor_ << "'"; absl::MutexLock sources_lock(&chunk_sources_mutex_); std::vector> sources_after_anchor; for (const auto& item : chunk_sources_) { if (item->source->GetChunkSortKey() > anchor_) { sources_after_anchor.push_back(item); } } LOG(INFO) << sources_after_anchor.size() << " chunk source(s) after anchor, " << chunks_since_anchor_ << " total chunks since anchor"; const size_t to_log = std::min(sources_after_anchor.size(), size_t(20)); for (size_t i = 0; i < to_log; ++i) { LOG(INFO) << " Source [" << (i + 1) << "/" << sources_after_anchor.size() << "]: key='" << sources_after_anchor[i]->source->GetChunkSortKey() << "', chunks=" << sources_after_anchor[i]->source->GetChunkCount(); } } if (initial_total_chunks == 0) { throw std::runtime_error( "ShufflingChunkPool requires at least one chunk during startup."); } } void ShufflingChunkPool::SourceIngestionWorker( std::stop_token stop_token, SourceIngestionThreadContext* context) { try { while (true) { auto chunk_source_with_phase = [&]() { LoadMetricPauser pauser(context->load_metric_updater); return input_queue()->Get(stop_token); }(); if (chunk_source_with_phase.message_type == FilePathProvider::MessageType::kFile) { // Ingest the new chunk source. 
auto source = std::move(chunk_source_with_phase.source); size_t chunk_count = source->GetChunkCount(); absl::MutexLock lock(&chunk_sources_mutex_); chunks_since_anchor_ += chunk_count; AddNewChunkSource(std::move(source)); } } } catch (const QueueClosedException&) { LOG(INFO) << "SourceIngestionWorker stopping, queue closed."; } catch (const QueueRequestCancelled&) { LOG(INFO) << "SourceIngestionWorker stopping, request cancelled."; } } void ShufflingChunkPool::OutputWorker(std::stop_token stop_token, ChunkLoadingThreadContext* context) { // Create a local producer for this worker auto primary_producer = output_queue()->CreateProducer(); std::optionalCreateProducer())> cachehit_producer; if (cachehit_output_queue_.has_value()) { cachehit_producer.emplace(cachehit_output_queue_->CreateProducer()); } try { while (true) { auto result = GetNextChunkData(); if (!result) { if (output_queue()->IsClosed()) break; continue; } LoadMetricPauser pauser(context->load_metric_updater); if (std::holds_alternative(*result)) { primary_producer.Put(std::move(std::get(*result)), stop_token); } else { cachehit_producer->Put(std::move(std::get(*result)), stop_token); } } } catch (const QueueClosedException&) { LOG(INFO) << "OutputWorker stopping, queue closed."; } catch (const QueueRequestCancelled&) { LOG(INFO) << "OutputWorker stopping, request cancelled."; } catch (const std::exception& e) { LOG(FATAL) << "OutputWorker encountered an error: " << e.what(); } } void ShufflingChunkPool::CachingWorker(std::stop_token stop_token, CachingThreadContext* context) { constexpr double kTheta = 0.99; double reminder = 0.0; double exponential_avg_probability = 1.0; try { while (true) { auto cache_request = [&]() { LoadMetricPauser pauser(context->load_metric_updater); return cache_request_queue_->Get(stop_token); }(); std::shared_ptr source_item; float max_weight = 0.0f; size_t local_index = 0; { absl::MutexLock lock(&chunk_sources_mutex_); // Find the chunk source containing this global index. 
auto it = absl::c_lower_bound( chunk_sources_, cache_request.global_index, [](const auto& item, size_t chunk_idx) { return item->start_chunk_index + item->source->GetChunkCount() <= chunk_idx; }); if (it == chunk_sources_.end() || cache_request.global_index < (*it)->start_chunk_index) { chunk_source_not_found_.fetch_add(1, std::memory_order_acq_rel); continue; } source_item = *it; max_weight = max_weight_; local_index = cache_request.global_index - source_item->start_chunk_index; } absl::MutexLock item_lock(&source_item->mutex); assert(local_index < source_item->use_counts.size()); // Check use_count match. if (source_item->use_counts[local_index] != cache_request.next_use) { mismatched_use_counts_.fetch_add(1, std::memory_order_acq_rel); continue; } // Compute how many positions to cache. const float weight = source_item->weight[local_index]; assert(weight >= 0.0f); const double probability = ComputeHanseProbability(weight, max_weight); exponential_avg_probability = exponential_avg_probability * (1.0 - kTheta) + probability * kTheta; const double n = (probability * config_.position_cache_size() / chunk_pool_size_ / exponential_avg_probability) + reminder; reminder = n - std::floor(n); const size_t positions_to_cache = static_cast(std::floor(n)); // Traverse and extend the cache chain. std::unique_ptr* current = &source_item->cache[local_index]; for (size_t i = 0; i < positions_to_cache; ++i) { if (*current) { current = &(*current)->next; continue; } if (i >= cache_request.items.size()) break; auto node = std::make_unique(); node->frame = cache_request.items[i]; *current = std::move(node); current = &(*current)->next; newly_cached_.fetch_add(1, std::memory_order_acq_rel); cached_positions_.fetch_add(1, std::memory_order_acq_rel); } const size_t dropped = cache_request.items.size() > positions_to_cache ? 
cache_request.items.size() - positions_to_cache : 0; dropped_cache_positions_.fetch_add(dropped, std::memory_order_acq_rel); } } catch (const QueueClosedException&) { LOG(INFO) << "CachingWorker stopping, queue closed."; } catch (const QueueRequestCancelled&) { LOG(INFO) << "CachingWorker stopping, request cancelled."; } } struct ShufflingChunkPool::ChunkData { std::vector data; std::string sort_key; size_t local_index = 0; size_t global_index = 0; uint32_t use_count = 0; std::shared_ptr source_item; }; std::optional> ShufflingChunkPool::GetNextChunkData() { while (true) { ChunkData chunk_data; const ChunkStatus status = GetChunkInfo(chunk_data); if (status == ChunkStatus::kEnd) return std::nullopt; if (status == ChunkStatus::kRetry) continue; const bool hanse_enabled = config_.hanse_sampling_threshold() > 0; if (hanse_enabled && !HanseAccept(chunk_data)) continue; { absl::MutexLock lock(&chunk_data.source_item->mutex); assert(chunk_data.source_item->use_counts.size() > chunk_data.local_index); chunk_data.use_count = chunk_data.source_item->use_counts[chunk_data.local_index]++; if (cachehit_output_queue_.has_value()) { auto& cache_chain = chunk_data.source_item->cache[chunk_data.local_index]; if (cache_chain) { cache_hits_.fetch_add(1, std::memory_order_acq_rel); cached_positions_.fetch_sub(1, std::memory_order_acq_rel); FrameType cached_frame = cache_chain->frame; cache_chain = std::move(cache_chain->next); return cached_frame; } cache_misses_.fetch_add(1, std::memory_order_acq_rel); } } if (chunk_data.data.empty() && !LoadChunkData(chunk_data)) continue; TrainingChunk chunk; chunk.sort_key = std::move(chunk_data.sort_key); chunk.index_within_sort_key = chunk_data.local_index; chunk.use_count = chunk_data.use_count; chunk.global_index = chunk_data.global_index; chunk.frames = std::move(chunk_data.data); return chunk; } } bool ShufflingChunkPool::LoadChunkData(ChunkData& chunk_data) { std::optional> data = 
chunk_data.source_item->source->GetChunkData(chunk_data.local_index); if (!data || data->empty()) { absl::MutexLock lock(&chunk_data.source_item->mutex); chunk_data.source_item->dropped_chunks.insert(chunk_data.local_index); dropped_chunks_metric_.fetch_add(1, std::memory_order_acq_rel); return false; } chunk_data.data = std::move(*data); return true; } ShufflingChunkPool::ChunkStatus ShufflingChunkPool::GetChunkInfo( ChunkData& out_chunk_data) { std::shared_ptr source_item; { absl::MutexLock lock(&chunk_sources_mutex_); std::optional chunk_index = stream_shuffler_.GetNextItem(); if (!chunk_index && !chunk_sources_.empty()) { size_t total_chunks = chunk_sources_.back()->start_chunk_index + chunk_sources_.back()->source->GetChunkCount(); size_t lower_bound = total_chunks > chunk_pool_size_ ? total_chunks - chunk_pool_size_ : chunk_sources_.front()->start_chunk_index; stream_shuffler_.Reset(lower_bound, total_chunks); reshuffles_.fetch_add(1, std::memory_order_acq_rel); chunk_index = stream_shuffler_.GetNextItem(); } if (!chunk_index) return ChunkStatus::kEnd; auto it = absl::c_lower_bound(chunk_sources_, *chunk_index, [](const auto& item, size_t chunk_idx) { return item->start_chunk_index + item->source->GetChunkCount() <= chunk_idx; }); if (ABSL_PREDICT_FALSE(it == chunk_sources_.end() || *chunk_index < (*it)->start_chunk_index)) { LOG(WARNING) << "Chunk index " << *chunk_index << " out of range for available chunk sources."; return ChunkStatus::kRetry; } source_item = *it; out_chunk_data.local_index = *chunk_index - source_item->start_chunk_index; out_chunk_data.sort_key = source_item->source->GetChunkSortKey(); out_chunk_data.global_index = *chunk_index; } { absl::MutexLock lock(&source_item->mutex); if (source_item->dropped_chunks.contains(out_chunk_data.local_index)) { return ChunkStatus::kRetry; } } out_chunk_data.source_item = std::move(source_item); return ChunkStatus::kOk; } double ShufflingChunkPool::ComputeHanseProbability(float weight, float max_weight) 
const { if (max_weight <= 0.0f) return 1.0; return std::pow(weight / max_weight, config_.hanse_sampling_gamma()); } float ShufflingChunkPool::ComputeChunkWeight( absl::Span frames) const { return absl::c_accumulate(frames, 0.0f, [this](float sum, const auto& frame) { return sum + ComputePositionSamplingWeight(frame, config_.position_sampling()); }); } bool ShufflingChunkPool::HanseAccept(ChunkData& chunk_data) { assert(chunk_data.source_item); float weight = -1.0f; { absl::MutexLock lock(&chunk_data.source_item->mutex); assert(chunk_data.source_item->weight.size() > chunk_data.local_index); weight = chunk_data.source_item->weight[chunk_data.local_index]; } if (weight < 0.0f) { if (chunk_data.data.empty() && !LoadChunkData(chunk_data)) return false; weight = ComputeChunkWeight(chunk_data.data); { absl::MutexLock pool_lock(&chunk_sources_mutex_); max_weight_ = std::max(max_weight_, weight); AddSample(chunk_weight_stats_, static_cast(weight)); } { absl::MutexLock lock(&chunk_data.source_item->mutex); const float cached_weight = chunk_data.source_item->weight[chunk_data.local_index]; if (cached_weight < 0.0f) { chunk_data.source_item->weight[chunk_data.local_index] = weight; } else { weight = cached_weight; } } hanse_cache_misses_.fetch_add(1, std::memory_order_acq_rel); } else { hanse_cache_hits_.fetch_add(1, std::memory_order_acq_rel); } float max_weight = 0.0f; { absl::MutexLock lock(&chunk_sources_mutex_); max_weight = max_weight_; } const double p = ComputeHanseProbability(weight, max_weight); const double u = absl::Uniform(bitgen_, 0.0, 1.0); if (u >= p) { hanse_rejected_.fetch_add(1, std::memory_order_acq_rel); return false; } return true; } void ShufflingChunkPool::AddNewChunkSource(std::unique_ptr source) ABSL_EXCLUSIVE_LOCKS_REQUIRED(chunk_sources_mutex_) { // Add new chunk source to the end of the deque. 
size_t old_upper_bound = 0; if (!chunk_sources_.empty()) { const auto& last_source = chunk_sources_.back(); old_upper_bound = last_source->start_chunk_index + last_source->source->GetChunkCount(); } size_t count = source->GetChunkCount(); auto item = std::make_shared(); item->start_chunk_index = old_upper_bound; item->source = std::move(source); item->use_counts = std::vector(count, 0); item->weight = std::vector(count, -1.0f); item->cache = std::vector>(cachehit_output_queue_.has_value() ? count : 0); chunk_sources_.push_back(std::move(item)); // Calculate current window bounds. size_t new_upper_bound = chunk_sources_.back()->start_chunk_index + chunk_sources_.back()->source->GetChunkCount(); // Remove old chunks if window exceeds chunk_pool_size_. while (!chunk_sources_.empty() && chunk_sources_.size() > 1) { size_t window_start = chunk_sources_.front()->start_chunk_index + chunk_sources_.front()->source->GetChunkCount(); size_t window_size = new_upper_bound - window_start; if (window_size < chunk_pool_size_) break; // Count cached positions in the evicted source. if (cachehit_output_queue_.has_value()) { size_t evicted_cached = 0; absl::MutexLock item_lock(&chunk_sources_.front()->mutex); for (const auto& cache_chain : chunk_sources_.front()->cache) { const CacheNode* node = cache_chain.get(); while (node) { ++evicted_cached; node = node->next.get(); } } cached_positions_.fetch_sub(evicted_cached, std::memory_order_acq_rel); } // Remove the oldest chunk source (front of deque). chunk_sources_.pop_front(); } // Update stream shuffler bounds with the sliding window. size_t window_start = chunk_sources_.front()->start_chunk_index; size_t new_lower_bound = new_upper_bound > chunk_pool_size_ ? 
new_upper_bound - chunk_pool_size_ : window_start; stream_shuffler_.SetUpperBound(new_upper_bound); stream_shuffler_.SetLowerBound(new_lower_bound); } StageMetricProto ShufflingChunkPool::FlushMetrics() { StageMetricProto stage_metric; // Aggregate source ingestion load metrics from all ingestion threads. LoadMetricProto ingestion_load; ingestion_load.set_name("source_ingestion"); for (const auto& context : source_ingestion_thread_contexts_) { UpdateFrom(ingestion_load, context->load_metric_updater.FlushMetrics()); } *stage_metric.add_load_metrics() = std::move(ingestion_load); // Aggregate chunk loading load metrics from all chunk loading threads. LoadMetricProto chunk_loading_load; chunk_loading_load.set_name("chunk_loading"); for (const auto& context : chunk_loading_thread_contexts_) { UpdateFrom(chunk_loading_load, context->load_metric_updater.FlushMetrics()); } *stage_metric.add_load_metrics() = std::move(chunk_loading_load); // Get chunk sources statistics and pool state. { absl::MutexLock lock(&chunk_sources_mutex_); auto* chunk_sources_metric = stage_metric.add_gauge_metrics(); chunk_sources_metric->set_name("chunk_sources"); chunk_sources_metric->set_value( static_cast(chunk_sources_.size())); size_t upper = 0; size_t current = 0; if (!chunk_sources_.empty()) { const auto& first = chunk_sources_.front(); const auto& last = chunk_sources_.back(); upper = last->start_chunk_index + last->source->GetChunkCount(); current = upper - first->start_chunk_index; } auto* current_chunks_metric = stage_metric.add_gauge_metrics(); current_chunks_metric->set_name("chunks_current"); current_chunks_metric->set_value(static_cast(current)); current_chunks_metric->set_capacity( static_cast(chunk_pool_size_)); auto* total_chunks_metric = stage_metric.add_gauge_metrics(); total_chunks_metric->set_name("chunks_total"); total_chunks_metric->set_value(static_cast(upper)); } // Get anchor-related metrics. 
{
    absl::MutexLock lock(&anchor_mutex_);
    auto* chunks_since_anchor_metric = stage_metric.add_gauge_metrics();
    chunks_since_anchor_metric->set_name("chunks_since_anchor");
    chunks_since_anchor_metric->set_value(chunks_since_anchor_);
    stage_metric.set_anchor(anchor_);
  }
  auto* dropped_metric = stage_metric.add_count_metrics();
  dropped_metric->set_name("dropped");
  dropped_metric->set_count(
      dropped_chunks_metric_.exchange(0, std::memory_order_acq_rel));
  // Hanse sampling and shuffler metrics.
  {
    auto* hits = stage_metric.add_count_metrics();
    hits->set_name("hanse_cache_hits");
    hits->set_count(hanse_cache_hits_.exchange(0, std::memory_order_acq_rel));
    auto* misses = stage_metric.add_count_metrics();
    misses->set_name("hanse_cache_misses");
    misses->set_count(
        hanse_cache_misses_.exchange(0, std::memory_order_acq_rel));
    auto* rejected = stage_metric.add_count_metrics();
    rejected->set_name("hanse_rejected");
    rejected->set_count(
        hanse_rejected_.exchange(0, std::memory_order_acq_rel));
    auto* resh = stage_metric.add_count_metrics();
    resh->set_name("reshuffles");
    resh->set_count(reshuffles_.exchange(0, std::memory_order_acq_rel));
  }
  // Position cache metrics.
if (cachehit_output_queue_.has_value()) {
    // Cache-related metrics only exist when the cache-hit output is enabled.
    LoadMetricProto caching_load;
    caching_load.set_name("caching");
    for (const auto& context : caching_thread_contexts_) {
      UpdateFrom(caching_load, context->load_metric_updater.FlushMetrics());
    }
    *stage_metric.add_load_metrics() = std::move(caching_load);
    auto* cache_hits = stage_metric.add_count_metrics();
    cache_hits->set_name("cache_hits");
    cache_hits->set_count(cache_hits_.exchange(0, std::memory_order_acq_rel));
    auto* cache_misses = stage_metric.add_count_metrics();
    cache_misses->set_name("cache_misses");
    cache_misses->set_count(
        cache_misses_.exchange(0, std::memory_order_acq_rel));
    auto* mismatched = stage_metric.add_count_metrics();
    mismatched->set_name("mismatched_use_counts");
    mismatched->set_count(
        mismatched_use_counts_.exchange(0, std::memory_order_acq_rel));
    auto* newly_cached = stage_metric.add_count_metrics();
    newly_cached->set_name("newly_cached");
    newly_cached->set_count(
        newly_cached_.exchange(0, std::memory_order_acq_rel));
    auto* dropped = stage_metric.add_count_metrics();
    dropped->set_name("dropped_cache_positions");
    dropped->set_count(
        dropped_cache_positions_.exchange(0, std::memory_order_acq_rel));
    auto* not_found = stage_metric.add_count_metrics();
    not_found->set_name("chunk_source_not_found");
    not_found->set_count(
        chunk_source_not_found_.exchange(0, std::memory_order_acq_rel));
    auto* cached = stage_metric.add_gauge_metrics();
    cached->set_name("cached_positions");
    cached->set_value(cached_positions_.load(std::memory_order_acquire));
    cached->set_capacity(config_.position_cache_size());
  }
  {
    absl::MutexLock lock(&chunk_sources_mutex_);
    if (chunk_weight_stats_.count() > 0) {
      chunk_weight_stats_.set_name("chunk_weight");
      UpdateFrom(*stage_metric.add_statistics_metrics(), chunk_weight_stats_);
    }
    chunk_weight_stats_.Clear();
  }
  *stage_metric.add_queue_metrics() =
      MetricsFromQueue(primary_output_name_, *output_queue());
  if (cachehit_output_queue_.has_value()) {
    *stage_metric.add_queue_metrics() =
MetricsFromQueue(*cachehit_output_name_, *cachehit_output_queue_); } return stage_metric; } std::pair ShufflingChunkPool::ResetAnchor() { absl::MutexLock anchor_lock(&anchor_mutex_); absl::MutexLock sources_lock(&chunk_sources_mutex_); if (chunk_sources_.empty()) { int previous_count = chunks_since_anchor_.exchange(0); return {anchor_, previous_count}; } anchor_ = chunk_sources_.back()->source->GetChunkSortKey(); int previous_count = chunks_since_anchor_.exchange(0); return {anchor_, previous_count}; } int ShufflingChunkPool::ChunksSinceAnchor() { return chunks_since_anchor_; } std::string ShufflingChunkPool::CurrentAnchor() { absl::MutexLock lock(&anchor_mutex_); return anchor_; } void ShufflingChunkPool::SetAnchor(std::string_view anchor) { absl::MutexLock lock(&anchor_mutex_); anchor_ = anchor; } std::optional ShufflingChunkPool::Control( const StageControlRequest& request) { if (!request.has_chunk_pool_request()) { return std::nullopt; } const auto& chunk_request = request.chunk_pool_request(); StageControlResponse response; auto* chunk_response = response.mutable_chunk_pool_response(); if (chunk_request.reset_chunk_anchor()) { auto [anchor, chunks] = ResetAnchor(); chunk_response->set_chunk_anchor(anchor); chunk_response->set_chunks_since_anchor(chunks); return response; } if (chunk_request.has_set_chunk_anchor()) { SetAnchor(chunk_request.set_chunk_anchor()); chunk_response->set_chunk_anchor(chunk_request.set_chunk_anchor()); chunk_response->set_chunks_since_anchor(ChunksSinceAnchor()); return response; } chunk_response->set_chunk_anchor(CurrentAnchor()); chunk_response->set_chunks_since_anchor(ChunksSinceAnchor()); return response; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/shuffling_chunk_pool.h ================================================ #pragma once #include #include #include #include #include #include #include #include #include #include #include 
"absl/base/thread_annotations.h" #include "absl/container/flat_hash_set.h" #include "absl/random/random.h" #include "absl/synchronization/mutex.h" #include "loader/chunk_source/chunk_source.h" #include "loader/data_loader_metrics.h" #include "loader/stages/chunk_source_loader.h" #include "loader/stages/stage.h" #include "loader/stages/training_chunk.h" #include "proto/data_loader_config.pb.h" #include "proto/stage_control.pb.h" #include "proto/training_metrics.pb.h" #include "utils/metrics/load_metric.h" #include "utils/queue.h" #include "utils/stream_shuffler.h" #include "utils/thread_pool.h" namespace lczero { namespace training { class ShufflingChunkPool : public Stage { public: explicit ShufflingChunkPool(const ShufflingChunkPoolConfig& config); ~ShufflingChunkPool(); void Start() override; void Stop() override; void SetInputs(absl::Span inputs) override; QueueBase* GetOutput(std::string_view name) override; StageMetricProto FlushMetrics() override; std::optional Control( const StageControlRequest& request) override; // Anchor management methods for tracking chunks since a specific point. std::pair ResetAnchor(); int ChunksSinceAnchor(); std::string CurrentAnchor(); void SetAnchor(std::string_view anchor); Queue* input_queue() { return primary_input_queue_; } Queue* output_queue() { return &primary_output_queue_; } private: struct CacheNode { FrameType frame; std::unique_ptr next; }; struct ChunkSourceItem { mutable absl::Mutex mutex; size_t start_chunk_index; std::unique_ptr source; absl::flat_hash_set dropped_chunks ABSL_GUARDED_BY(mutex); // Per-chunk counters and cached weights. 
std::vector use_counts ABSL_GUARDED_BY(mutex); std::vector weight ABSL_GUARDED_BY(mutex); std::vector> cache ABSL_GUARDED_BY(mutex); }; struct SourceIngestionThreadContext { LoadMetricUpdater load_metric_updater; }; struct ChunkLoadingThreadContext { LoadMetricUpdater load_metric_updater; }; struct CachingThreadContext { LoadMetricUpdater load_metric_updater; }; std::vector> InitializeChunkSources(); void ProcessInputFiles( std::vector> uninitialized_sources); void SourceIngestionWorker(std::stop_token stop_token, SourceIngestionThreadContext* context); void OutputWorker(std::stop_token stop_token, ChunkLoadingThreadContext* context); void CachingWorker(std::stop_token stop_token, CachingThreadContext* context); void AddNewChunkSource(std::unique_ptr source) ABSL_EXCLUSIVE_LOCKS_REQUIRED(chunk_sources_mutex_); std::optional> GetNextChunkData(); enum class ChunkStatus { kOk, kRetry, kEnd }; struct ChunkData; ChunkStatus GetChunkInfo(ChunkData& out_chunk_data); bool LoadChunkData(ChunkData& chunk_data); bool HanseAccept(ChunkData& chunk_data); float ComputeChunkWeight(absl::Span frames) const; double ComputeHanseProbability(float weight, float max_weight) const; Queue* primary_input_queue_ = nullptr; Queue* cache_request_queue_ = nullptr; std::string primary_output_name_; Queue primary_output_queue_; std::optional cachehit_output_name_; std::optional> cachehit_output_queue_; const size_t chunk_pool_size_; const ShufflingChunkPoolConfig config_; // stop_source_ must be declared before ThreadPools that reference it. 
std::stop_source stop_source_; ThreadPool source_ingestion_pool_; ThreadPool chunk_loading_pool_; ThreadPool caching_pool_; std::atomic dropped_chunks_metric_{0}; absl::Mutex chunk_sources_mutex_; std::deque> chunk_sources_ ABSL_GUARDED_BY(chunk_sources_mutex_); StreamShuffler stream_shuffler_ ABSL_GUARDED_BY(chunk_sources_mutex_); float max_weight_ ABSL_GUARDED_BY(chunk_sources_mutex_) = 0.0f; std::jthread initialization_thread_; std::vector> source_ingestion_thread_contexts_; std::vector> chunk_loading_thread_contexts_; std::vector> caching_thread_contexts_; // Anchor-related members for tracking chunks since a specific point. absl::Mutex anchor_mutex_; std::string anchor_ ABSL_GUARDED_BY(anchor_mutex_); std::atomic chunks_since_anchor_{0}; // Thread-local RNG for Hanse sampling. static thread_local absl::BitGen bitgen_; // Metrics counters. std::atomic hanse_cache_hits_{0}; std::atomic hanse_cache_misses_{0}; std::atomic hanse_rejected_{0}; std::atomic reshuffles_{0}; std::atomic cache_hits_{0}; std::atomic cache_misses_{0}; std::atomic mismatched_use_counts_{0}; std::atomic newly_cached_{0}; std::atomic dropped_cache_positions_{0}; std::atomic chunk_source_not_found_{0}; std::atomic cached_positions_{0}; StatisticsProtoDouble chunk_weight_stats_ ABSL_GUARDED_BY(chunk_sources_mutex_); }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/shuffling_chunk_pool_test.cc ================================================ // ABOUTME: Comprehensive unit tests for the ShufflingChunkPool class // ABOUTME: Tests chunk source management, output workers, and dynamic windowing #include "loader/stages/shuffling_chunk_pool.h" #include #include #include #include #include #include #include #include #include #include #include #include #include "loader/stages/training_chunk.h" namespace lczero { namespace training { namespace { template class PassthroughStage : public Stage { public: explicit 
PassthroughStage(Queue* queue) : queue_(queue) {} void Start() override {} void Stop() override {} StageMetricProto FlushMetrics() override { return StageMetricProto(); } QueueBase* GetOutput(std::string_view name = "") override { (void)name; return queue_; } void SetInputs(absl::Span inputs) override { if (!inputs.empty()) { throw std::runtime_error("PassthroughStage expects no inputs"); } } private: Queue* queue_; }; } // namespace // Mock ChunkSource for testing class MockChunkSource : public ChunkSource { public: MockChunkSource(const std::string& sort_key, size_t chunk_count) : sort_key_(sort_key), chunk_count_(chunk_count) {} std::string GetChunkSortKey() const override { return sort_key_; } size_t GetChunkCount() const override { return chunk_count_; } std::optional> GetChunkData(size_t index) override { if (index >= chunk_count_) { throw std::out_of_range("Chunk index out of range"); } FrameType frame{}; frame.version = static_cast(index); frame.input_format = 3; return std::vector{frame}; } private: std::string sort_key_; size_t chunk_count_; }; class InvalidChunkSource : public ChunkSource { public: explicit InvalidChunkSource(std::string sort_key) : sort_key_(std::move(sort_key)) {} std::string GetChunkSortKey() const override { return sort_key_; } size_t GetChunkCount() const override { return 2; } std::optional> GetChunkData(size_t index) override { if (index >= 2) { throw std::out_of_range("Chunk index out of range"); } if (index == 0) { return std::nullopt; } FrameType frame{}; frame.version = 42; return std::vector{frame}; } private: std::string sort_key_; }; class ShufflingChunkPoolTest : public ::testing::Test { protected: void SetUp() override { input_queue_ = std::make_unique>(100); input_producer_ = std::make_unique::Producer>( input_queue_->CreateProducer()); } void TearDown() override { // Close the producer to close the queue if (input_producer_) input_producer_.reset(); } // Helper to add a mock chunk source to the input queue void 
AddMockChunkSourceToQueue(const std::string& sort_key, size_t chunk_count, FilePathProvider::MessageType message_type = FilePathProvider::MessageType::kFile) { ChunkSourceWithPhase item; item.source = std::make_unique(sort_key, chunk_count); item.message_type = message_type; input_producer_->Put(std::move(item)); } void MarkInitialScanComplete() { ChunkSourceWithPhase item; item.source = nullptr; // No source for completion marker item.message_type = FilePathProvider::MessageType::kInitialScanComplete; input_producer_->Put(std::move(item)); } void CloseInputQueue() { if (input_producer_) input_producer_.reset(); } ShufflingChunkPoolConfig MakeConfig(int chunk_pool_size, int source_ingestion_threads = 1, int loading_threads = 1, int queue_capacity = 100) const { ShufflingChunkPoolConfig config; config.set_chunk_pool_size(chunk_pool_size); config.set_source_ingestion_threads(source_ingestion_threads); config.set_chunk_loading_threads(loading_threads); config.mutable_output()->set_queue_capacity(queue_capacity); return config; } std::unique_ptr> input_queue_; std::unique_ptr::Producer> input_producer_; }; TEST_F(ShufflingChunkPoolTest, ConstructorCreatesOutputQueue) { // Add some mock chunk sources with enough chunks AddMockChunkSourceToQueue("source1", 50); AddMockChunkSourceToQueue("source2", 60); MarkInitialScanComplete(); auto config = MakeConfig(20); ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); auto* output_queue = shuffling_chunk_pool.output_queue(); // Close input queue to stop input worker from waiting CloseInputQueue(); EXPECT_NE(output_queue, nullptr); EXPECT_EQ(output_queue->Capacity(), 100); // Drain output queue to prevent workers from blocking try { while (output_queue->Size() > 0) { output_queue->Get(); } } catch (const QueueClosedException&) { // Queue closed, that's fine } } TEST_F(ShufflingChunkPoolTest, HandlesEmptyInputQueue) { // Only mark scan complete, no chunk sources 
MarkInitialScanComplete(); auto config = MakeConfig(20); // Constructor should now succeed (initialization is asynchronous) ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); shuffling_chunk_pool.Start(); // The initialization thread should handle the error case auto* output_queue = shuffling_chunk_pool.output_queue(); // Give the initialization thread time to complete and discover the error std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Close input queue to clean up CloseInputQueue(); // Output queue should exist but be closed to signal startup failure when no // chunks were found. EXPECT_NE(output_queue, nullptr); EXPECT_TRUE(output_queue->IsClosed()); EXPECT_EQ(output_queue->Size(), 0u); } TEST_F(ShufflingChunkPoolTest, FlushMetricsHandlesEmptyChunkSources) { const int chunk_pool_size = 32; auto config = MakeConfig(chunk_pool_size); ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); auto metrics = shuffling_chunk_pool.FlushMetrics(); bool found_current = false; bool found_total = false; for (const auto& metric : metrics.gauge_metrics()) { if (metric.name() == "chunks_current") { found_current = true; EXPECT_EQ(metric.value(), 0u); EXPECT_EQ(metric.capacity(), static_cast(chunk_pool_size)); } else if (metric.name() == "chunks_total") { found_total = true; EXPECT_EQ(metric.value(), 0u); } } EXPECT_TRUE(found_current) << "FlushMetrics should emit chunks_current metric when empty."; EXPECT_TRUE(found_total) << "FlushMetrics should emit chunks_total metric when empty."; } TEST_F(ShufflingChunkPoolTest, FlushMetricsReportsWindowAndTotalCounts) { AddMockChunkSourceToQueue("initial", 30); MarkInitialScanComplete(); const int chunk_pool_size = 20; ShufflingChunkPool shuffling_chunk_pool(MakeConfig(chunk_pool_size)); shuffling_chunk_pool.SetInputs({input_queue_.get()}); shuffling_chunk_pool.Start(); auto* output_queue = 
shuffling_chunk_pool.output_queue(); output_queue->WaitForSizeAtLeast(1); uint64_t current_count = 0; uint64_t total_count = 0; uint64_t current_capacity = 0; bool found_metrics = false; for (int attempt = 0; attempt < 50 && !found_metrics; ++attempt) { auto metrics = shuffling_chunk_pool.FlushMetrics(); bool has_current = false; bool has_total = false; for (const auto& metric : metrics.gauge_metrics()) { if (metric.name() == "chunks_current") { has_current = true; current_count = metric.value(); current_capacity = metric.capacity(); } else if (metric.name() == "chunks_total") { has_total = true; total_count = metric.value(); } } if (has_current && has_total) { found_metrics = true; break; } std::this_thread::sleep_for(std::chrono::milliseconds(10)); } ASSERT_TRUE(found_metrics) << "FlushMetrics should report both chunks_current and chunks_total."; EXPECT_EQ(current_count, 30u); EXPECT_EQ(current_capacity, static_cast(chunk_pool_size)); EXPECT_EQ(total_count, 30u); CloseInputQueue(); } TEST_F(ShufflingChunkPoolTest, ProcessesInitialScanChunkSources) { // Create mock chunk sources with enough chunks AddMockChunkSourceToQueue("source1", 30); AddMockChunkSourceToQueue("source2", 40); AddMockChunkSourceToQueue("source3", 50); MarkInitialScanComplete(); auto config = MakeConfig(20); // Test that constructor completes and processes mock chunk sources EXPECT_NO_THROW({ ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); shuffling_chunk_pool.Start(); // Close input queue to stop input worker from waiting CloseInputQueue(); auto* output_queue = shuffling_chunk_pool.output_queue(); EXPECT_NE(output_queue, nullptr); }); } TEST_F(ShufflingChunkPoolTest, OutputWorkerProducesChunks) { // Create mock chunk sources AddMockChunkSourceToQueue("source1", 10, FilePathProvider::MessageType::kFile); AddMockChunkSourceToQueue("source2", 15, FilePathProvider::MessageType::kFile); MarkInitialScanComplete(); auto config = MakeConfig(20); 
ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); shuffling_chunk_pool.Start(); // Close input queue to stop input worker from waiting CloseInputQueue(); auto* output_queue = shuffling_chunk_pool.output_queue(); // Wait for output workers to produce at least one chunk output_queue->WaitForSizeAtLeast(1); // Should have some chunks available EXPECT_GT(output_queue->Size(), 0); // Get a chunk and verify it's from our mock sources auto chunk = output_queue->Get(); EXPECT_FALSE(chunk.frames.empty()); EXPECT_TRUE(chunk.sort_key == "source1" || chunk.sort_key == "source2"); EXPECT_EQ(chunk.frames.size(), 1); EXPECT_EQ(chunk.frames.front().version, static_cast(chunk.index_within_sort_key)); EXPECT_EQ(chunk.use_count, 0u); } TEST_F(ShufflingChunkPoolTest, DropsInvalidChunks) { ChunkSourceWithPhase invalid_source; invalid_source.source = std::make_unique("invalid_source"); invalid_source.message_type = FilePathProvider::MessageType::kFile; input_producer_->Put(std::move(invalid_source)); MarkInitialScanComplete(); auto config = MakeConfig(2, /*source_ingestion_threads=*/1, /*loading_threads=*/1, /*queue_capacity=*/10); ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); shuffling_chunk_pool.Start(); // Close input queue to stop input worker from waiting CloseInputQueue(); auto* output_queue = shuffling_chunk_pool.output_queue(); output_queue->WaitForSizeAtLeast(1); auto chunk = output_queue->Get(); EXPECT_EQ(chunk.sort_key, "invalid_source"); EXPECT_EQ(chunk.index_within_sort_key, 1); EXPECT_EQ(chunk.use_count, 0u); ASSERT_EQ(chunk.frames.size(), 1); EXPECT_EQ(chunk.frames.front().version, 42); uint64_t dropped_latest = 0; bool found_dropped = false; for (int attempt = 0; attempt < 50 && !found_dropped; ++attempt) { auto metrics = shuffling_chunk_pool.FlushMetrics(); for (const auto& metric : metrics.count_metrics()) { if (metric.name() == "dropped" && metric.count() > 0) 
{ dropped_latest = metric.count(); found_dropped = true; break; } } if (!found_dropped) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } } ASSERT_TRUE(found_dropped) << "dropped chunk metrics should be reported"; EXPECT_GE(dropped_latest, 1u); } TEST_F(ShufflingChunkPoolTest, NewChunkSourceProcessing) { // Start with initial scan and one chunk source - use enough chunks to satisfy // window AddMockChunkSourceToQueue("initial", 120); // More chunks than window MarkInitialScanComplete(); auto config = MakeConfig(20); ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); shuffling_chunk_pool.Start(); // Verify chunks are being produced from initial sources auto* output_queue = shuffling_chunk_pool.output_queue(); output_queue->WaitForSizeAtLeast(1); EXPECT_NE(output_queue, nullptr); EXPECT_GT(output_queue->Size(), 0); // Add a new chunk source after initialization AddMockChunkSourceToQueue("new_source", 30, FilePathProvider::MessageType::kFile); // Close input queue to stop input worker from waiting for more CloseInputQueue(); // The chunk set should still be functional and continue producing chunks // from both the initial and new sources EXPECT_GT(output_queue->Size(), 0); } TEST_F(ShufflingChunkPoolTest, ChunkWindowManagement) { // Create more chunks than the window size AddMockChunkSourceToQueue("source1", 30); AddMockChunkSourceToQueue("source2", 30); AddMockChunkSourceToQueue("source3", 30); MarkInitialScanComplete(); auto config = MakeConfig(50); // Should only keep sources that fit in the window EXPECT_NO_THROW({ ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); shuffling_chunk_pool.Start(); // Close input queue to stop input worker from waiting CloseInputQueue(); auto* output_queue = shuffling_chunk_pool.output_queue(); EXPECT_NE(output_queue, nullptr); }); } // Test the ShufflingChunkPoolConfig structure TEST_F(ShufflingChunkPoolTest, 
ChunkSorting) { // Add chunk sources in non-sorted order (by sort key) AddMockChunkSourceToQueue("source_b", 20); AddMockChunkSourceToQueue("source_a", 25); AddMockChunkSourceToQueue("source_c", 30); MarkInitialScanComplete(); auto config = MakeConfig(70); // ShufflingChunkPool should handle sorting internally (newest first) EXPECT_NO_THROW({ ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); shuffling_chunk_pool.Start(); // Close input queue to stop input worker from waiting CloseInputQueue(); auto* output_queue = shuffling_chunk_pool.output_queue(); EXPECT_NE(output_queue, nullptr); }); } TEST_F(ShufflingChunkPoolTest, StreamShufflerResetWhenExhausted) { // Create a small chunk source to quickly exhaust the shuffler AddMockChunkSourceToQueue("source1", 3); // Only 3 chunks for faster testing MarkInitialScanComplete(); auto config = MakeConfig(3, /*source_ingestion_threads=*/1, /*loading_threads=*/1, /*queue_capacity=*/100); // Large enough ShufflingChunkPool shuffling_chunk_pool(config); shuffling_chunk_pool.SetInputs({input_queue_.get()}); shuffling_chunk_pool.Start(); auto* output_queue = shuffling_chunk_pool.output_queue(); // Collect chunks continuously and count total chunks received struct ChunkRecord { std::string sort_key; size_t index; uint32_t use_count; }; std::vector all_chunks_received; // Wait for and collect chunks to test shuffler reset for (size_t i = 0; i < 8; ++i) { output_queue->WaitForSizeAtLeast(1); auto chunk = output_queue->Get(); all_chunks_received.push_back( {chunk.sort_key, chunk.index_within_sort_key, chunk.use_count}); } std::set> unique_chunks; bool seen_reuse = false; for (const auto& record : all_chunks_received) { unique_chunks.emplace(record.sort_key, record.index); if (record.use_count > 0) { seen_reuse = true; } } // Close input queue to clean up try { CloseInputQueue(); } catch (const QueueClosedException&) { // Already closed, that's fine } // We should see all 3 unique 
chunks from our source EXPECT_EQ(unique_chunks.size(), 3) << "Should see all unique chunks"; // If reset works properly, we should receive more than 3 total chunks // (since chunks will repeat after shuffler reset) EXPECT_GT(all_chunks_received.size(), 3) << "Should get more than 3 chunks total due to shuffler reset, got " << all_chunks_received.size() << " chunks"; EXPECT_TRUE(seen_reuse) << "Expect at least one chunk to report a reuse count"; } TEST_F(ShufflingChunkPoolTest, HanseMetrics_NoRejection_CacheAndReshuffles) { // Single chunk so we will continually reuse the same chunk. AddMockChunkSourceToQueue("source1", 1); MarkInitialScanComplete(); auto config = MakeConfig(1, /*source_ingestion_threads=*/1, /*loading_threads=*/1, /*queue_capacity=*/100); // Enable Hanse sampling with p == 1 to avoid rejections. config.set_hanse_sampling_threshold(1); ShufflingChunkPool pool(config); pool.SetInputs({input_queue_.get()}); pool.Start(); auto* output_queue = pool.output_queue(); // Wait for multiple outputs to exercise cache hits and reshuffles. output_queue->WaitForSizeAtLeast(3); // Drain a few items. for (int i = 0; i < 3; ++i) { auto chunk = output_queue->Get(); EXPECT_EQ(chunk.frames.size(), 1u); } // Close input to avoid lingering. CloseInputQueue(); // Flush metrics and validate Hanse counters and reshuffles. auto metrics = pool.FlushMetrics(); uint64_t cache_hits = 0, cache_misses = 0, rejected = 0, reshuffles = 0; for (const auto& m : metrics.count_metrics()) { if (m.name() == "hanse_cache_hits") cache_hits = m.count(); if (m.name() == "hanse_cache_misses") cache_misses = m.count(); if (m.name() == "hanse_rejected") rejected = m.count(); if (m.name() == "reshuffles") reshuffles = m.count(); } // First access computes and caches num_records => 1 miss, then hits. EXPECT_EQ(cache_misses, 1u); EXPECT_GE(cache_hits, 1u); // With threshold=1 and one frame, p = 1 => no rejections. EXPECT_EQ(rejected, 0u); // Single chunk repeatedly consumed forces reshuffles. 
EXPECT_GT(reshuffles, 0u);
}

TEST_F(ShufflingChunkPoolTest, ExplicitClose) {
  // Create chunk sources
  AddMockChunkSourceToQueue("source1", 20);
  AddMockChunkSourceToQueue("source2", 30);
  MarkInitialScanComplete();

  auto config = MakeConfig(40);
  ShufflingChunkPool shuffling_chunk_pool(config);
  shuffling_chunk_pool.SetInputs({input_queue_.get()});
  shuffling_chunk_pool.Start();
  auto* output_queue = shuffling_chunk_pool.output_queue();

  // Wait for workers to produce some chunks
  output_queue->WaitForSizeAtLeast(1);

  // Verify output queue is working before close
  EXPECT_GT(output_queue->Size(), 0);

  // Explicitly stop the chunk set
  shuffling_chunk_pool.Stop();

  // Drain all remaining items from the queue
  while (output_queue->Size() > 0) {
    output_queue->Get();
  }

  // Now the queue should be closed and empty, so Get() should throw
  EXPECT_THROW(output_queue->Get(), QueueClosedException);

  CloseInputQueue();
}

TEST_F(ShufflingChunkPoolTest, CloseStopsOutputWorkers) {
  // Create chunk sources
  AddMockChunkSourceToQueue("source1", 15);
  MarkInitialScanComplete();

  auto config = MakeConfig(15, /*source_ingestion_threads=*/1,
                           /*loading_threads=*/2, /*queue_capacity=*/50);
  ShufflingChunkPool shuffling_chunk_pool(config);
  shuffling_chunk_pool.SetInputs({input_queue_.get()});
  shuffling_chunk_pool.Start();
  auto* output_queue = shuffling_chunk_pool.output_queue();

  // Wait for workers to produce chunks
  output_queue->WaitForSizeAtLeast(1);
  size_t chunks_before_close = output_queue->Size();

  // Stop the chunk set
  shuffling_chunk_pool.Stop();

  // Drain any remaining chunks from the queue
  try {
    while (output_queue->Size() > 0) {
      output_queue->Get();
    }
  } catch (const QueueClosedException&) {
    // Expected when queue is empty and closed
  }

  // Should have had chunks before close
  EXPECT_GT(chunks_before_close, 0) << "Should have had chunks before close";

  CloseInputQueue();
}

TEST_F(ShufflingChunkPoolTest, CloseIsIdempotent) {
  // Create chunk sources
  AddMockChunkSourceToQueue("source1", 20);
  MarkInitialScanComplete();

  auto config = MakeConfig(20);
  ShufflingChunkPool shuffling_chunk_pool(config);
  shuffling_chunk_pool.SetInputs({input_queue_.get()});
  shuffling_chunk_pool.Start();

  // Stop multiple times - should not crash or cause issues
  EXPECT_NO_THROW(shuffling_chunk_pool.Stop());
  EXPECT_NO_THROW(shuffling_chunk_pool.Stop());
  EXPECT_NO_THROW(shuffling_chunk_pool.Stop());

  CloseInputQueue();
}

TEST_F(ShufflingChunkPoolTest, DestructorCallsClose) {
  // Create chunk sources
  AddMockChunkSourceToQueue("source1", 20);
  MarkInitialScanComplete();

  auto config = MakeConfig(20);
  // Test that destructor calls Close() and properly shuts down
  {
    ShufflingChunkPool shuffling_chunk_pool(config);
    shuffling_chunk_pool.SetInputs({input_queue_.get()});
    shuffling_chunk_pool.Start();
    auto* output_queue = shuffling_chunk_pool.output_queue();

    // Wait for workers to produce some chunks
    output_queue->WaitForSizeAtLeast(1);
    EXPECT_GT(output_queue->Size(), 0);

    // Close input queue before destructor to allow threads to finish
    CloseInputQueue();

    // ShufflingChunkPool destructor should be called here, which calls Stop()
    // and waits for all threads to finish
  }

  // Test passes if destructor completes without hanging
  // (we can't test the queue state after destruction since it's destroyed)
}

TEST_F(ShufflingChunkPoolTest, InputQueueClosureDoesNotCloseOutputQueue) {
  // Create chunk sources
  AddMockChunkSourceToQueue("source1", 30);
  MarkInitialScanComplete();

  auto config = MakeConfig(30);
  ShufflingChunkPool shuffling_chunk_pool(config);
  shuffling_chunk_pool.SetInputs({input_queue_.get()});
  shuffling_chunk_pool.Start();
  auto* output_queue = shuffling_chunk_pool.output_queue();

  // Wait for workers to produce some chunks
  output_queue->WaitForSizeAtLeast(1);
  EXPECT_GT(output_queue->Size(), 0);

  // Close input queue (simulating end of file discovery)
  CloseInputQueue();

  // Output queue should still be functional - workers should continue
  // producing chunks from existing chunk sources
  // Should still be able to get chunks (queue not closed)
  EXPECT_NO_THROW(output_queue->Get());

  // Explicitly stop to clean up
  shuffling_chunk_pool.Stop();
}

TEST_F(ShufflingChunkPoolTest, BasicAnchorFunctionality) {
  AddMockChunkSourceToQueue("source1", 20);
  MarkInitialScanComplete();

  auto config = MakeConfig(20);
  ShufflingChunkPool pool(config);
  pool.SetInputs({input_queue_.get()});
  pool.Start();

  // Test initial state
  EXPECT_EQ(pool.ChunksSinceAnchor(), 0);
  EXPECT_EQ(pool.CurrentAnchor(), "");

  // Test SetAnchor and CurrentAnchor
  pool.SetAnchor("test_anchor_key");
  EXPECT_EQ(pool.CurrentAnchor(), "test_anchor_key");
  EXPECT_EQ(pool.ChunksSinceAnchor(), 0);  // Should still be 0

  // Test setting different anchor
  pool.SetAnchor("another_key");
  EXPECT_EQ(pool.CurrentAnchor(), "another_key");

  CloseInputQueue();
}

TEST_F(ShufflingChunkPoolTest, ResetAnchor) {
  AddMockChunkSourceToQueue("source1", 20);
  MarkInitialScanComplete();

  auto config = MakeConfig(20);
  ShufflingChunkPool pool(config);
  pool.SetInputs({input_queue_.get()});
  pool.Start();

  // Wait for initialization to complete
  pool.output_queue()->WaitForSizeAtLeast(1);

  // Now test ResetAnchor
  auto [anchor, count_before] = pool.ResetAnchor();
  EXPECT_FALSE(anchor.empty());  // Should have the chunk key
  EXPECT_EQ(pool.CurrentAnchor(), anchor);
  EXPECT_EQ(pool.ChunksSinceAnchor(), 0);  // Should be reset to 0

  CloseInputQueue();
}

TEST_F(ShufflingChunkPoolTest, AnchorCounterIncrement) {
  // Don't mark initial scan complete yet - we'll add sources one by one
  auto config = MakeConfig(20);

  // Start with some initial sources and complete scan
  AddMockChunkSourceToQueue("source1", 20);
  MarkInitialScanComplete();

  ShufflingChunkPool pool(config);
  pool.SetInputs({input_queue_.get()});
  pool.Start();

  // Set anchor to a key that won't match our new sources
  pool.SetAnchor("non_matching_key");

  // Wait for initial load to complete
  pool.output_queue()->WaitForSizeAtLeast(1);

  // Now add new sources (these should increment the counter)
  // Note: We can't add
more sources after initial scan complete in the current // setup So we'll test the counter after the initial load int final_count = pool.ChunksSinceAnchor(); // Counter should have incremented during initial load since anchor doesn't // match EXPECT_GT(final_count, 0); EXPECT_EQ(pool.CurrentAnchor(), "non_matching_key"); // Anchor unchanged CloseInputQueue(); } TEST_F(ShufflingChunkPoolTest, AnchorCounterResetDuringInitialLoad) { // Test the special case where anchor is encountered during initial backward // processing AddMockChunkSourceToQueue("source_c", 10); // newest AddMockChunkSourceToQueue("source_b", 15); // middle AddMockChunkSourceToQueue("source_a", 20); // oldest auto config = MakeConfig(45); ShufflingChunkPool pool(config); pool.SetInputs({input_queue_.get()}); pool.Start(); // Set anchor to middle source before marking scan complete pool.SetAnchor("source_b"); // Mark scan complete to trigger initial processing MarkInitialScanComplete(); // Wait for initial load to complete pool.output_queue()->WaitForSizeAtLeast(1); int final_count = pool.ChunksSinceAnchor(); // Should only count chunks from source_c (10 chunks) since it is newer than // the anchor. 
EXPECT_EQ(final_count, 10); EXPECT_EQ(pool.CurrentAnchor(), "source_b"); CloseInputQueue(); } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/shuffling_frame_sampler.cc ================================================ #include "loader/stages/shuffling_frame_sampler.h" #include #include "absl/algorithm/container.h" #include "absl/log/log.h" #include "absl/random/uniform_int_distribution.h" #include "loader/data_loader_metrics.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" namespace lczero { namespace training { ShufflingFrameSampler::ShufflingFrameSampler( const ShufflingFrameSamplerConfig& config) : SingleInputStage(config), SingleOutputStage(config.output()), reservoir_size_per_thread_(config.reservoir_size_per_thread()), thread_pool_(config.threads(), ThreadPoolOptions{}) { LOG(INFO) << "Initializing ShufflingFrameSampler with " << config.threads() << " threads, reservoir size " << config.reservoir_size_per_thread(); // Initialize thread contexts but don't start worker threads yet. 
thread_contexts_.reserve(config.threads()); for (size_t i = 0; i < config.threads(); ++i) { thread_contexts_.push_back(std::make_unique()); } } ShufflingFrameSampler::~ShufflingFrameSampler() { Stop(); } void ShufflingFrameSampler::Start() { LOG(INFO) << "Starting ShufflingFrameSampler worker threads."; for (size_t i = 0; i < thread_contexts_.size(); ++i) { thread_pool_.Enqueue([this, i](std::stop_token stop_token) { Worker(stop_token, thread_contexts_[i].get()); }); } } void ShufflingFrameSampler::Stop() { if (thread_pool_.stop_token().stop_requested()) return; LOG(INFO) << "Stopping ShufflingFrameSampler."; thread_pool_.Shutdown(); output_queue()->Close(); LOG(INFO) << "ShufflingFrameSampler stopped."; } void ShufflingFrameSampler::Worker(std::stop_token stop_token, ThreadContext* context) { // Create producer early so that if input queue closes during reservoir // prefilling, the producer will be destroyed and close the output queue. auto producer = output_queue()->CreateProducer(); absl::FixedArray reservoir(reservoir_size_per_thread_); try { // Phase 1: Prefill the reservoir LOG(INFO) << "ShufflingFrameSampler worker prefilling reservoir"; absl::c_generate(reservoir, [this, context, stop_token]() { LoadMetricPauser pauser(context->load_metric_updater); return input_queue()->Get(stop_token); }); // Phase 2: Main sampling loop MainSamplingLoop(stop_token, reservoir, producer, context); } catch (const QueueClosedException&) { LOG(INFO) << "ShufflingFrameSampler worker stopping, queue closed."; } catch (const QueueRequestCancelled&) { LOG(INFO) << "ShufflingFrameSampler worker stopping, request cancelled."; } } void ShufflingFrameSampler::MainSamplingLoop( std::stop_token stop_token, absl::FixedArray& reservoir, Queue::Producer& producer, ThreadContext* context) { absl::uniform_int_distribution dist(0, reservoir.size() - 1); while (true) { const size_t random_index = dist(gen_); { LoadMetricPauser pauser(context->load_metric_updater); 
producer.Put(std::move(reservoir[random_index]), stop_token); } { LoadMetricPauser pauser(context->load_metric_updater); reservoir[random_index] = input_queue()->Get(stop_token); } } } StageMetricProto ShufflingFrameSampler::FlushMetrics() { StageMetricProto stage_metric; LoadMetricProto aggregated_load; aggregated_load.set_name("load"); for (const auto& context : thread_contexts_) { UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics()); } *stage_metric.add_load_metrics() = std::move(aggregated_load); *stage_metric.add_queue_metrics() = MetricsFromQueue("output", *output_queue()); return stage_metric; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/shuffling_frame_sampler.h ================================================ // ABOUTME: Stage that provides shuffled frames using reservoir sampling. // ABOUTME: Takes FrameType frames and outputs them in randomized order. #pragma once #include #include #include #include #include #include "absl/container/fixed_array.h" #include "absl/random/random.h" #include "loader/data_loader_metrics.h" #include "loader/frame_type.h" #include "loader/stages/stage.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" #include "utils/queue.h" #include "utils/thread_pool.h" namespace lczero { namespace training { // Worker that implements reservoir sampling for training frames. // Takes FrameType frames as input and outputs them in shuffled order // using reservoir sampling algorithm. 
class ShufflingFrameSampler : public SingleInputStage, public SingleOutputStage { public: using InputType = FrameType; using OutputType = FrameType; explicit ShufflingFrameSampler(const ShufflingFrameSamplerConfig& config); ~ShufflingFrameSampler(); void Start() override; void Stop() override; StageMetricProto FlushMetrics() override; private: struct ThreadContext { LoadMetricUpdater load_metric_updater; }; void Worker(std::stop_token stop_token, ThreadContext* context); void MainSamplingLoop(std::stop_token stop_token, absl::FixedArray& reservoir, Queue::Producer& producer, ThreadContext* context); size_t reservoir_size_per_thread_; absl::BitGen gen_; // thread_contexts_ must be declared before thread_pool_ to ensure // thread_pool_ is destroyed first (stopping threads before contexts). std::vector> thread_contexts_; ThreadPool thread_pool_; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/shuffling_frame_sampler_test.cc ================================================ #include "loader/stages/shuffling_frame_sampler.h" #include #include #include "gtest/gtest.h" #include "libs/lc0/src/trainingdata/trainingdata_v6.h" #include "utils/queue.h" namespace lczero { namespace training { namespace { template class PassthroughStage : public Stage { public: explicit PassthroughStage(Queue* queue) : queue_(queue) {} void Start() override {} void Stop() override {} StageMetricProto FlushMetrics() override { return StageMetricProto(); } QueueBase* GetOutput(std::string_view name = "") override { (void)name; return queue_; } void SetInputs(absl::Span inputs) override { if (!inputs.empty()) { throw std::runtime_error("PassthroughStage expects no inputs"); } } private: Queue* queue_; }; } // namespace class ShufflingFrameSamplerTest : public ::testing::Test { protected: void SetUp() override { input_queue_ = std::make_unique>(100); config_.set_reservoir_size_per_thread(10); // Small size for testing 
config_.mutable_output()->set_queue_capacity(20); } FrameType CreateTestFrame(uint32_t version) { FrameType frame{}; frame.version = version; frame.input_format = 3; frame.root_q = 0.5f; return frame; } std::unique_ptr> input_queue_; ShufflingFrameSamplerConfig config_; }; TEST_F(ShufflingFrameSamplerTest, OutputsNoFramesWithSmallInput) { ShufflingFrameSampler sampler(config_); sampler.SetInputs({input_queue_.get()}); sampler.Start(); // Send 5 frames (less than reservoir size) auto producer = input_queue_->CreateProducer(); std::vector input_versions = {1, 2, 3, 4, 5}; for (auto version : input_versions) { producer.Put(CreateTestFrame(version)); } producer.Close(); // Collect all output frames std::set output_versions; try { while (true) { auto frame = sampler.output_queue()->Get(); output_versions.insert(frame.version); } } catch (const QueueClosedException&) { // Expected when queue is closed } // With fewer inputs than reservoir size, no frames should be output // (they remain in the reservoir) EXPECT_EQ(output_versions.size(), 0); } TEST_F(ShufflingFrameSamplerTest, OutputsFramesWithLargeInput) { ShufflingFrameSampler sampler(config_); sampler.SetInputs({input_queue_.get()}); sampler.Start(); // Send 20 frames (more than reservoir size of 10) auto producer = input_queue_->CreateProducer(); std::vector input_versions; for (uint32_t i = 1; i <= 20; ++i) { input_versions.push_back(i); producer.Put(CreateTestFrame(i)); } producer.Close(); // Collect all output frames std::set output_versions; try { while (true) { auto frame = sampler.output_queue()->Get(); output_versions.insert(frame.version); } } catch (const QueueClosedException&) { // Expected when queue is closed } // Should output exactly 11 frames (10 during sampling + 1 final frame before // queue closes) EXPECT_EQ(output_versions.size(), 11); // All output frames should be from the input set for (auto version : output_versions) { EXPECT_TRUE(std::find(input_versions.begin(), input_versions.end(), version) 
!= input_versions.end()); } } TEST_F(ShufflingFrameSamplerTest, HandlesEmptyInput) { ShufflingFrameSampler sampler(config_); sampler.SetInputs({input_queue_.get()}); sampler.Start(); // Close input queue without sending data input_queue_->Close(); // Should not output any frames EXPECT_THROW(sampler.output_queue()->Get(), QueueClosedException); } TEST_F(ShufflingFrameSamplerTest, HandlesExactReservoirSize) { ShufflingFrameSampler sampler(config_); sampler.SetInputs({input_queue_.get()}); sampler.Start(); // Send exactly reservoir_size_per_thread frames auto producer = input_queue_->CreateProducer(); std::vector input_versions; for (uint32_t i = 1; i <= config_.reservoir_size_per_thread(); ++i) { input_versions.push_back(i); producer.Put(CreateTestFrame(i)); } producer.Close(); // Collect all output frames std::set output_versions; try { while (true) { auto frame = sampler.output_queue()->Get(); output_versions.insert(frame.version); } } catch (const QueueClosedException&) { // Expected when queue is closed } // With exactly reservoir size frames, 1 frame should be output // (fills reservoir, then queue closes during first sampling attempt) EXPECT_EQ(output_versions.size(), 1); } TEST_F(ShufflingFrameSamplerTest, PreservesFrameData) { config_.set_reservoir_size_per_thread(2); ShufflingFrameSampler sampler(config_); sampler.SetInputs({input_queue_.get()}); sampler.Start(); auto producer = input_queue_->CreateProducer(); // Create frames with specific data - need more than reservoir size FrameType frame1 = CreateTestFrame(100); frame1.root_q = 0.1f; frame1.input_format = 1; FrameType frame2 = CreateTestFrame(200); frame2.root_q = 0.2f; frame2.input_format = 2; FrameType frame3 = CreateTestFrame(300); frame3.root_q = 0.3f; frame3.input_format = 3; producer.Put(frame1); producer.Put(frame2); producer.Put(frame3); // This will cause frame1 to be output producer.Close(); // Verify frame data is preserved std::vector output_frames; try { while (true) { 
output_frames.push_back(sampler.output_queue()->Get()); } } catch (const QueueClosedException&) { // Expected } EXPECT_EQ(output_frames.size(), 2); // Should be frames that were displaced from the reservoir during sampling std::set output_frame_versions; for (const auto& frame : output_frames) { output_frame_versions.insert(frame.version); // Verify frame data is preserved if (frame.version == 100) { EXPECT_EQ(frame.root_q, 0.1f); EXPECT_EQ(frame.input_format, 1); } else if (frame.version == 200) { EXPECT_EQ(frame.root_q, 0.2f); EXPECT_EQ(frame.input_format, 2); } } } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/simple_chunk_extractor.cc ================================================ #include "loader/stages/simple_chunk_extractor.h" #include #include #include #include "loader/data_loader_metrics.h" namespace lczero { namespace training { SimpleChunkExtractor::SimpleChunkExtractor( const SimpleChunkExtractorConfig& config) : SingleInputStage( config), SingleOutputStage(config.output()), bitgen_(absl::MakeSeedSeq()) {} SimpleChunkExtractor::~SimpleChunkExtractor() { Stop(); } void SimpleChunkExtractor::Start() { thread_pool_.Enqueue( [this](std::stop_token stop_token) { Worker(stop_token); }); } void SimpleChunkExtractor::Stop() { if (thread_pool_.stop_token().stop_requested()) return; LOG(INFO) << "Stopping SimpleChunkExtractor."; thread_pool_.Shutdown(); output_queue()->Close(); } void SimpleChunkExtractor::Worker(std::stop_token stop_token) { auto producer = output_queue()->CreateProducer(); try { while (true) { auto item = input_queue()->Get(stop_token); if (item.message_type != FilePathProvider::MessageType::kFile || !item.source) { continue; } ProcessSource(producer, std::move(item.source), stop_token); } } catch (const QueueClosedException&) { } } void SimpleChunkExtractor::ProcessSource( Queue::Producer& producer, std::unique_ptr source, std::stop_token stop_token) { const size_t 
chunk_count = source->GetChunkCount(); if (chunk_count == 0) return; std::vector indices(chunk_count); std::iota(indices.begin(), indices.end(), 0); absl::c_shuffle(indices, bitgen_); const std::string sort_key = source->GetChunkSortKey(); for (size_t idx : indices) { if (auto chunk = LoadChunk(*source, sort_key, idx)) { producer.Put(std::move(*chunk), stop_token); ++chunks_processed_; } } ++sources_processed_; } std::optional SimpleChunkExtractor::LoadChunk( ChunkSource& source, const std::string& sort_key, size_t index) { auto data = source.GetChunkData(index); if (!data || data->empty()) { ++chunks_dropped_; return std::nullopt; } TrainingChunk chunk; chunk.sort_key = sort_key; chunk.index_within_sort_key = index; chunk.global_index = chunks_processed_; chunk.use_count = 0; chunk.frames = std::move(*data); return chunk; } StageMetricProto SimpleChunkExtractor::FlushMetrics() { StageMetricProto metric; auto add_count = [&](const char* name, std::atomic& counter) { auto* m = metric.add_count_metrics(); m->set_name(name); m->set_count(counter.exchange(0)); }; add_count("chunks_processed", chunks_processed_); add_count("chunks_dropped", chunks_dropped_); add_count("sources_processed", sources_processed_); *metric.add_queue_metrics() = MetricsFromQueue("output", *output_queue()); return metric; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/simple_chunk_extractor.h ================================================ #pragma once #include #include #include #include #include #include #include "absl/random/random.h" #include "loader/chunk_source/chunk_source.h" #include "loader/stages/chunk_source_loader.h" #include "loader/stages/stage.h" #include "loader/stages/training_chunk.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" #include "utils/queue.h" #include "utils/thread_pool.h" namespace lczero { namespace training { // Single-threaded stage that shuffles chunks 
within each source. class SimpleChunkExtractor : public SingleInputStage, public SingleOutputStage { public: explicit SimpleChunkExtractor(const SimpleChunkExtractorConfig& config); ~SimpleChunkExtractor(); void Start() override; void Stop() override; StageMetricProto FlushMetrics() override; private: void Worker(std::stop_token stop_token); void ProcessSource(Queue::Producer& producer, std::unique_ptr source, std::stop_token stop_token); std::optional LoadChunk(ChunkSource& source, const std::string& sort_key, size_t index); std::atomic chunks_processed_{0}; std::atomic chunks_dropped_{0}; std::atomic sources_processed_{0}; absl::BitGen bitgen_; ThreadPool thread_pool_{1}; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/simple_chunk_extractor_test.cc ================================================ #include "loader/stages/simple_chunk_extractor.h" #include #include #include #include #include "loader/chunk_source/chunk_source.h" #include "loader/stages/chunk_source_loader.h" #include "loader/stages/file_path_provider.h" #include "loader/stages/stage.h" #include "loader/stages/training_chunk.h" #include "proto/data_loader_config.pb.h" #include "utils/queue.h" namespace lczero { namespace training { namespace { // Mock chunk source for testing. class MockChunkSource : public ChunkSource { public: MockChunkSource(std::string sort_key, size_t chunk_count) : sort_key_(std::move(sort_key)), chunk_count_(chunk_count) { // Pre-generate chunk data. for (size_t i = 0; i < chunk_count; ++i) { chunks_.emplace_back(10); // 10 frames per chunk. 
} } std::string GetChunkSortKey() const override { return sort_key_; } size_t GetChunkCount() const override { return chunk_count_; } std::optional> GetChunkData(size_t index) override { if (index >= chunks_.size()) return std::nullopt; return chunks_[index]; } private: std::string sort_key_; size_t chunk_count_; std::vector> chunks_; }; class SimpleChunkExtractorTest : public ::testing::Test { protected: void SetUp() override { // Create input queue. input_queue_ = std::make_unique>(10); // Add a dummy stage to the registry. StageConfig dummy_config; dummy_config.set_name("dummy_input"); registry_.AddStage("dummy_input", std::make_unique(input_queue_.get())); // Create the shuffler config. SimpleChunkExtractorConfig config; config.set_input("dummy_input"); config.set_queue_capacity(10); // Create the shuffler stage. shuffler_ = std::make_unique(config, registry_); } void TearDown() override { if (shuffler_) { shuffler_->Stop(); } input_queue_->Close(); } // Helper class to provide a dummy stage for the registry. class DummyStage : public Stage { public: explicit DummyStage(QueueBase* queue) : queue_(queue) {} void Start() override {} void Stop() override {} StageMetricProto FlushMetrics() override { return {}; } QueueBase* GetOutput(std::string_view name = "") override { (void)name; return queue_; } private: QueueBase* queue_; }; StageRegistry registry_; std::unique_ptr> input_queue_; std::unique_ptr shuffler_; }; TEST_F(SimpleChunkExtractorTest, ProcessesSingleSource) { shuffler_->Start(); auto producer = input_queue_->CreateProducer(); // Send a chunk source with 5 chunks. auto source = std::make_unique("source1", 5); producer.Put({.source = std::move(source), .message_type = FilePathProvider::MessageType::kFile}); // Close input to signal completion. input_queue_->Close(); // Collect all output chunks. 
auto* output = static_cast*>(shuffler_->GetOutput()); std::vector chunks; while (true) { try { chunks.push_back(output->Get()); } catch (const QueueClosedException&) { break; } } // Should receive exactly 5 chunks. EXPECT_EQ(chunks.size(), 5); // All chunks should have the same sort_key. for (const auto& chunk : chunks) { EXPECT_EQ(chunk.sort_key, "source1"); EXPECT_EQ(chunk.frames.size(), 10); // 10 frames per chunk. } // Check that all chunk indices are present (though order is shuffled). std::vector indices; for (const auto& chunk : chunks) { indices.push_back(chunk.index_within_sort_key); } std::sort(indices.begin(), indices.end()); EXPECT_THAT(indices, ::testing::ElementsAre(0, 1, 2, 3, 4)); } TEST_F(SimpleChunkExtractorTest, ProcessesMultipleSources) { shuffler_->Start(); auto producer = input_queue_->CreateProducer(); // Send two chunk sources. producer.Put({.source = std::make_unique("source1", 3), .message_type = FilePathProvider::MessageType::kFile}); producer.Put({.source = std::make_unique("source2", 2), .message_type = FilePathProvider::MessageType::kFile}); input_queue_->Close(); // Collect all output chunks. auto* output = static_cast*>(shuffler_->GetOutput()); std::vector chunks; while (true) { try { chunks.push_back(output->Get()); } catch (const QueueClosedException&) { break; } } // Should receive 3 + 2 = 5 chunks total. EXPECT_EQ(chunks.size(), 5); // Count chunks per source. size_t source1_count = 0; size_t source2_count = 0; for (const auto& chunk : chunks) { if (chunk.sort_key == "source1") { ++source1_count; } else if (chunk.sort_key == "source2") { ++source2_count; } } EXPECT_EQ(source1_count, 3); EXPECT_EQ(source2_count, 2); } TEST_F(SimpleChunkExtractorTest, SkipsNonFileMessages) { shuffler_->Start(); auto producer = input_queue_->CreateProducer(); // Send a non-file message. producer.Put( {.source = nullptr, .message_type = FilePathProvider::MessageType::kInitialScanComplete}); // Send a file message. 
producer.Put({.source = std::make_unique("source1", 2), .message_type = FilePathProvider::MessageType::kFile}); input_queue_->Close(); // Collect all output chunks. auto* output = static_cast*>(shuffler_->GetOutput()); std::vector chunks; while (true) { try { chunks.push_back(output->Get()); } catch (const QueueClosedException&) { break; } } // Should only receive 2 chunks from the file message. EXPECT_EQ(chunks.size(), 2); } TEST_F(SimpleChunkExtractorTest, MetricsAreRecorded) { shuffler_->Start(); auto producer = input_queue_->CreateProducer(); producer.Put({.source = std::make_unique("source1", 3), .message_type = FilePathProvider::MessageType::kFile}); input_queue_->Close(); // Wait for processing to complete. auto* output = static_cast*>(shuffler_->GetOutput()); while (true) { try { output->Get(); } catch (const QueueClosedException&) { break; } } // Flush metrics. auto metrics = shuffler_->FlushMetrics(); EXPECT_GT(metrics.count_metrics_size(), 0); // Check that chunks_processed metric exists. 
bool found_chunks_processed = false; for (const auto& metric : metrics.count_metrics()) { if (metric.name() == "chunks_processed") { EXPECT_EQ(metric.count(), 3); found_chunks_processed = true; } } EXPECT_TRUE(found_chunks_processed); } } // namespace } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/stage.cc ================================================ #include "loader/stages/stage.h" #include namespace lczero { namespace training { void StageRegistry::AddStage(std::string_view stage_name, std::unique_ptr stage) { if (absl::c_find_if(stages_, [&](const auto& pair) { return pair.first == stage_name; }) != stages_.end()) { throw std::runtime_error( absl::StrCat("Duplicate stage name detected: ", stage_name)); } stages_.emplace_back(stage_name, std::move(stage)); } QueueBase* StageRegistry::GetStageOutput(std::string_view stage_name) const { auto [actual_stage_name, output_name] = [&stage_name]() { size_t dot_pos = stage_name.find('.'); return dot_pos == std::string_view::npos ? std::pair{stage_name, std::string_view{}} : std::pair{stage_name.substr(0, dot_pos), stage_name.substr(dot_pos + 1)}; }(); auto it = absl::c_find_if(stages_, [&](const auto& pair) { return pair.first == actual_stage_name; }); return it != stages_.end() ? it->second->GetOutput(output_name) : nullptr; } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/stage.h ================================================ #pragma once #include #include #include #include #include #include #include #include "absl/strings/str_cat.h" #include "proto/data_loader_config.pb.h" #include "proto/stage_control.pb.h" #include "proto/training_metrics.pb.h" #include "utils/queue.h" namespace lczero { namespace training { // Base interface implemented by all loader stages. class Stage { public: virtual ~Stage() = default; // Starts background workers owned by the stage. 
virtual void Start() = 0; // Requests the stage to stop and join background work. virtual void Stop() = 0; // Flushes stage-specific metrics and returns a snapshot. virtual StageMetricProto FlushMetrics() = 0; // Returns the output queue for downstream stages. virtual QueueBase* GetOutput(std::string_view name = "") = 0; // Sets the input queues for this stage. Called after construction but before // Start(). virtual void SetInputs(absl::Span inputs) = 0; // Handles control-plane messages specific to the stage. virtual std::optional Control( const StageControlRequest& request) { (void)request; return std::nullopt; } }; class StageRegistry { public: // Registers a new stage with the given name and takes ownership of it. void AddStage(std::string_view stage_name, std::unique_ptr stage); // Returns the output queue for the specified stage. // If stage_name contains a dot (e.g., "stage.output"), splits it into stage // name and output name, passing the output name to Stage::GetOutput(). // Returns nullptr if the stage is not found. QueueBase* GetStageOutput(std::string_view stage_name) const; template Queue* GetTypedStageOutput(std::string_view stage_name) const { QueueBase* raw_queue = GetStageOutput(stage_name); if (raw_queue == nullptr) return nullptr; auto* typed_queue = dynamic_cast*>(raw_queue); if (!typed_queue) { throw std::runtime_error( absl::StrCat("Stage output type mismatch for stage: ", stage_name)); } return typed_queue; } size_t size() const { return stages_.size(); } const std::vector>>& stages() const { return stages_; } private: std::vector>> stages_; }; // Helper to convert QueueConfig::OverflowBehavior to OverflowBehavior. 
inline OverflowBehavior ToOverflowBehavior( QueueConfig::OverflowBehavior behavior) { switch (behavior) { case QueueConfig::BLOCK: return OverflowBehavior::BLOCK; case QueueConfig::DROP_NEW: return OverflowBehavior::DROP_NEW; case QueueConfig::KEEP_NEWEST: return OverflowBehavior::KEEP_NEWEST; } throw std::runtime_error(absl::StrCat("Unknown OverflowBehavior value: ", static_cast(behavior))); } // Helper for stages that consume a single upstream queue. template class SingleInputStage : virtual public Stage { public: void SetInputs(absl::Span inputs) override { if (inputs.size() != 1) { throw std::runtime_error(absl::StrCat( "SingleInputStage expects exactly 1 input, got ", inputs.size())); } auto* typed_queue = dynamic_cast*>(inputs[0]); if (!typed_queue) throw std::runtime_error("Input queue type mismatch"); input_queue_ = typed_queue; } protected: explicit SingleInputStage(const ConfigT&) : input_queue_(nullptr) {} Queue* input_queue() { return input_queue_; } private: Queue* input_queue_; }; // Helper for stages that produce a single output queue. 
template class SingleOutputStage : virtual public Stage { public: Queue* output_queue() { return &output_queue_; } QueueBase* GetOutput(std::string_view name = "") override { if (name != output_name_) { throw std::runtime_error(absl::StrCat("Output name '", name, "' does not match configured '", output_name_, "'")); } return &output_queue_; } protected: explicit SingleOutputStage(const QueueConfig& config) : output_name_(config.name()), output_queue_(config.queue_capacity(), ToOverflowBehavior(config.overflow_behavior())) {} private: std::string output_name_; Queue output_queue_; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/stage_factory.cc ================================================ #include "loader/stages/stage_factory.h" #include #include #include "loader/stages/chunk_rescorer.h" #include "loader/stages/chunk_source_loader.h" #include "loader/stages/chunk_source_splitter.h" #include "loader/stages/chunk_unpacker.h" #include "loader/stages/file_path_provider.h" #include "loader/stages/join_stage.h" #include "loader/stages/shuffling_chunk_pool.h" #include "loader/stages/shuffling_frame_sampler.h" #include "loader/stages/simple_chunk_extractor.h" #include "loader/stages/tensor_generator.h" namespace lczero { namespace training { namespace { int CountStageConfigs(const StageConfig& config) { return static_cast(config.has_file_path_provider()) + static_cast(config.has_chunk_source_loader()) + static_cast(config.has_shuffling_chunk_pool()) + static_cast(config.has_chunk_rescorer()) + static_cast(config.has_chunk_unpacker()) + static_cast(config.has_shuffling_frame_sampler()) + static_cast(config.has_tensor_generator()) + static_cast(config.has_chunk_source_splitter()) + static_cast(config.has_simple_chunk_extractor()) + static_cast(config.has_join_positions()); } } // namespace std::unique_ptr CreateStage(const StageConfig& config) { if (CountStageConfigs(config) != 1) { throw 
std::runtime_error( "StageConfig must have exactly one stage-specific config set."); } if (config.has_file_path_provider()) { return std::make_unique(config.file_path_provider()); } if (config.has_chunk_source_loader()) { return std::make_unique(config.chunk_source_loader()); } if (config.has_shuffling_chunk_pool()) { return std::make_unique(config.shuffling_chunk_pool()); } if (config.has_chunk_rescorer()) { return std::make_unique(config.chunk_rescorer()); } if (config.has_chunk_unpacker()) { return std::make_unique(config.chunk_unpacker()); } if (config.has_shuffling_frame_sampler()) { return std::make_unique( config.shuffling_frame_sampler()); } if (config.has_tensor_generator()) { return std::make_unique(config.tensor_generator()); } if (config.has_chunk_source_splitter()) { return std::make_unique( config.chunk_source_splitter()); } if (config.has_simple_chunk_extractor()) { return std::make_unique( config.simple_chunk_extractor()); } if (config.has_join_positions()) { return std::make_unique(config.join_positions()); } throw std::runtime_error( "StageConfig did not contain a recognized stage configuration."); } } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/stage_factory.h ================================================ #pragma once #include #include "loader/stages/stage.h" #include "proto/data_loader_config.pb.h" namespace lczero { namespace training { std::unique_ptr CreateStage(const StageConfig& config); } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/stage_factory_test.cc ================================================ #include "loader/stages/stage_factory.h" #include #include namespace lczero { namespace training { TEST(StageFactoryTest, CreatesFilePathProviderStage) { StageConfig config; config.mutable_file_path_provider()->set_directory("."); auto stage = CreateStage(config); ASSERT_NE(stage, nullptr); 
  EXPECT_NE(stage->GetOutput(), nullptr);
}

// A config with no sub-config set must be rejected.
TEST(StageFactoryTest, ThrowsWhenNoStageConfigSet) {
  StageConfig config;
  EXPECT_THROW(CreateStage(config), std::runtime_error);
}

// A config with more than one sub-config set must be rejected.
TEST(StageFactoryTest, ThrowsWhenMultipleStageConfigsSet) {
  StageConfig config;
  config.mutable_file_path_provider()->set_directory(".");
  config.mutable_tensor_generator();
  EXPECT_THROW(CreateStage(config), std::runtime_error);
}

}  // namespace training
}  // namespace lczero



================================================
FILE: csrc/loader/stages/tensor_generator.cc
================================================
// ABOUTME: Implementation of TensorGenerator stage for training pipeline.
// ABOUTME: Converts V6TrainingData frames to batched tensors for training.
#include "loader/stages/tensor_generator.h"

// NOTE(review): the <...> header names below (and template arguments such as
// std::make_unique<...>, TypedTensor<float>, static_cast<...> throughout this
// file) were stripped by the text extraction — verify against the repo.
#include
#include
#include
#include
#include

#include "absl/algorithm/container.h"
#include "absl/log/log.h"
#include "loader/data_loader_metrics.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"

namespace lczero {
namespace training {

// Constructs the stage: binds input/output queues via the base helpers,
// records batch size and sizes the worker pool. No work runs until Start().
TensorGenerator::TensorGenerator(const TensorGeneratorConfig& config)
    : SingleInputStage(config),
      SingleOutputStage(config.output()),
      batch_size_(config.batch_size()),
      thread_pool_(config.threads(), ThreadPoolOptions{}) {
  LOG(INFO) << "Initializing TensorGenerator with " << config.threads()
            << " threads, batch size " << config.batch_size();
  // Initialize thread contexts but don't start worker threads yet.
  thread_contexts_.reserve(config.threads());
  for (size_t i = 0; i < config.threads(); ++i) {
    thread_contexts_.push_back(std::make_unique());
  }
}

// Ensures workers are stopped before members are destroyed.
TensorGenerator::~TensorGenerator() { Stop(); }

// Enqueues one Worker per thread context onto the pool.
void TensorGenerator::Start() {
  LOG(INFO) << "Starting TensorGenerator worker threads.";
  for (size_t i = 0; i < thread_contexts_.size(); ++i) {
    thread_pool_.Enqueue([this, i](std::stop_token stop_token) {
      Worker(stop_token, thread_contexts_[i].get());
    });
  }
}

// Idempotent shutdown: requests stop, joins the pool, then closes the output
// queue so downstream consumers observe end-of-stream.
void TensorGenerator::Stop() {
  if (thread_pool_.stop_token().stop_requested()) return;
  LOG(INFO) << "Stopping TensorGenerator.";
  thread_pool_.Shutdown();
  output_queue()->Close();
  LOG(INFO) << "TensorGenerator stopped.";
}

// Worker loop: gathers batch_size_ frames from the input queue, converts them
// into a TensorTuple, and publishes it. Exits via queue-closed/cancelled
// exceptions. LoadMetricPauser excludes queue-wait time from the load metric.
void TensorGenerator::Worker(std::stop_token stop_token,
                             ThreadContext* context) {
  auto producer = output_queue()->CreateProducer();
  std::vector batch;
  batch.reserve(batch_size_);
  try {
    while (true) {
      // Collect frames for a batch.
      batch.clear();
      for (size_t i = 0; i < batch_size_; ++i) {
        LoadMetricPauser pauser(context->load_metric_updater);
        batch.push_back(input_queue()->Get(stop_token));
      }
      // Convert batch to tensors.
      TensorTuple tensors = ConvertFramesToTensors(batch);
      {
        LoadMetricPauser pauser(context->load_metric_updater);
        producer.Put(std::move(tensors), stop_token);
      }
    }
  } catch (const QueueClosedException&) {
    LOG(INFO) << "TensorGenerator worker stopping, queue closed.";
  } catch (const QueueRequestCancelled&) {
    LOG(INFO) << "TensorGenerator worker stopping, request cancelled.";
  }
}

// Builds the three training tensors (planes, policy, values) from a batch.
TensorTuple TensorGenerator::ConvertFramesToTensors(
    const std::vector& frames) {
  const size_t batch_size = frames.size();
  constexpr size_t kNumPlanes = 112;
  constexpr size_t kNumPolicyMoves = 1858;
  constexpr size_t kNumValueTypes = 6;
  constexpr size_t kValuesPerType = 3;

  TensorTuple result;
  result.reserve(3);

  // Index 0: Input planes (batch_size, 112, 8, 8)
  auto planes_tensor = std::make_unique>(
      std::initializer_list{batch_size, kNumPlanes, 8, 8});
  ProcessPlanes(frames, *planes_tensor);
  result.push_back(std::move(planes_tensor));

  // Index 1: Probabilities (batch_size, 1858)
  auto probs_tensor = std::make_unique>(
      std::initializer_list{batch_size, kNumPolicyMoves});
  for (size_t i = 0; i < batch_size; ++i) {
    auto probs_slice = probs_tensor->slice({static_cast(i)});
    std::memcpy(probs_slice.data(), frames[i].probabilities,
                kNumPolicyMoves * sizeof(float));
  }
  result.push_back(std::move(probs_tensor));

  // Index 2: Values (batch_size, 6, 3) with [q, d, m] for each type.
  // [0]: result, [1]: best, [2]: played, [3]: orig, [4]: root, [5]: st
  auto values_tensor =
      std::make_unique>(std::initializer_list{
          batch_size, kNumValueTypes, kValuesPerType});
  for (size_t i = 0; i < batch_size; ++i) {
    const auto& frame = frames[i];
    auto batch_slice = values_tensor->slice({static_cast(i)});
    // Index 0: result [result_q, result_d, plies_left]
    auto result_slice = batch_slice.subspan(0 * kValuesPerType, kValuesPerType);
    result_slice[0] = frame.result_q;
    result_slice[1] = frame.result_d;
    result_slice[2] = frame.plies_left;
    // Index 1: best [best_q, best_d, best_m]
    auto best_slice = batch_slice.subspan(1 * kValuesPerType, kValuesPerType);
    best_slice[0] = frame.best_q;
    best_slice[1] = frame.best_d;
    best_slice[2] = frame.best_m;
    // Index 2: played [played_q, played_d, played_m]
    auto played_slice = batch_slice.subspan(2 * kValuesPerType, kValuesPerType);
    played_slice[0] = frame.played_q;
    played_slice[1] = frame.played_d;
    played_slice[2] = frame.played_m;
    // Index 3: orig [orig_q, orig_d, orig_m] (may be NaN)
    auto orig_slice = batch_slice.subspan(3 * kValuesPerType, kValuesPerType);
    orig_slice[0] = frame.orig_q;
    orig_slice[1] = frame.orig_d;
    orig_slice[2] = frame.orig_m;
    // Index 4: root [root_q, root_d, root_m]
    auto root_slice = batch_slice.subspan(4 * kValuesPerType, kValuesPerType);
    root_slice[0] = frame.root_q;
    root_slice[1] = frame.root_d;
    root_slice[2] = frame.root_m;
    // Index 5: st [q_st, d_st, NaN]
    auto st_slice = batch_slice.subspan(5 * kValuesPerType, kValuesPerType);
    st_slice[0] = frame.q_st;
    st_slice[1] = frame.d_st;
    st_slice[2] = std::numeric_limits::quiet_NaN();
  }
  result.push_back(std::move(values_tensor));

  return result;
}

// Expands bitboard planes into float planes and appends the 8 metadata planes.
void TensorGenerator::ProcessPlanes(const std::vector& frames,
                                    TypedTensor& planes_tensor) {
  const size_t batch_size = frames.size();
  for (size_t i = 0; i < batch_size; ++i) {
    const auto& frame = frames[i];
    auto batch_slice = planes_tensor.slice({static_cast(i)});
    // Process first 104 planes from frame.planes (each uint64_t represents 64
    // bits).
    for (ssize_t plane = 0; plane < 104; ++plane) {
      auto plane_slice = batch_slice.subspan(plane * 64, 64);
      uint64_t plane_bits = frame.planes[plane];
      for (ssize_t square = 0; square < 64; ++square) {
        // XOR with 7 remaps the index within each byte from 0..7 to 7..0.
        plane_slice[square] =
            static_cast((plane_bits >> (square ^ 7)) & 1);
      }
    }
    // Add 8 additional planes for metadata (planes 104-111).
    const std::pair meta_planes[] = {
        {104, static_cast(frame.castling_us_ooo)},
        {105, static_cast(frame.castling_us_oo)},
        {106, static_cast(frame.castling_them_ooo)},
        {107, static_cast(frame.castling_them_oo)},
        {108, static_cast(frame.side_to_move_or_enpassant)},
        {109, static_cast(frame.rule50_count) / 99.0f},
        {110, 0.0f},  // All zeros (constant plane).
        {111, 1.0f},  // All ones (constant plane).
    };
    for (const auto& [plane_num, value] : meta_planes) {
      auto plane_slice = batch_slice.subspan(plane_num * 64, 64);
      absl::c_fill(plane_slice, value);
    }
  }
}

// Aggregates per-thread load metrics into a single "load" entry and reports
// output queue statistics.
StageMetricProto TensorGenerator::FlushMetrics() {
  StageMetricProto stage_metric;
  LoadMetricProto aggregated_load;
  aggregated_load.set_name("load");
  for (const auto& context : thread_contexts_) {
    UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());
  }
  *stage_metric.add_load_metrics() = std::move(aggregated_load);
  *stage_metric.add_queue_metrics() =
      MetricsFromQueue("output", *output_queue());
  return stage_metric;
}

}  // namespace training
}  // namespace lczero



================================================
FILE: csrc/loader/stages/tensor_generator.h
================================================
// ABOUTME: Stage that converts FrameType frames into tensor batches.
// ABOUTME: Produces TrainingTensors with tensors for training pipeline.
#pragma once #include #include #include #include #include #include "loader/data_loader.h" #include "loader/data_loader_metrics.h" #include "loader/frame_type.h" #include "loader/stages/stage.h" #include "proto/data_loader_config.pb.h" #include "proto/training_metrics.pb.h" #include "utils/queue.h" #include "utils/tensor.h" #include "utils/thread_pool.h" namespace lczero { namespace training { // Worker pool that converts FrameType frames into tensor batches. // Takes individual FrameType frames as input and outputs TensorTuple // containing batched tensors in the format required for training. class TensorGenerator : public SingleInputStage, public SingleOutputStage { public: using InputType = FrameType; using OutputType = TensorTuple; explicit TensorGenerator(const TensorGeneratorConfig& config); ~TensorGenerator(); void Start() override; void Stop() override; StageMetricProto FlushMetrics() override; private: struct ThreadContext { LoadMetricUpdater load_metric_updater; }; void Worker(std::stop_token stop_token, ThreadContext* context); TensorTuple ConvertFramesToTensors(const std::vector& frames); void ProcessPlanes(const std::vector& frames, TypedTensor& planes_tensor); size_t batch_size_; // thread_contexts_ must be declared before thread_pool_ to ensure // thread_pool_ is destroyed first (stopping threads before contexts). std::vector> thread_contexts_; ThreadPool thread_pool_; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/loader/stages/tensor_generator_test.cc ================================================ // ABOUTME: Unit tests for TensorGenerator stage in training pipeline. // ABOUTME: Tests tensor conversion, batching, and data format correctness. 
#include "loader/stages/tensor_generator.h"

// NOTE(review): <...> include names and template arguments in this file were
// stripped by the text extraction — verify against the repo.
#include
#include
#include

#include "gtest/gtest.h"
#include "libs/lc0/src/trainingdata/trainingdata_v6.h"
#include "utils/queue.h"
#include "utils/tensor.h"

namespace lczero {
namespace training {
namespace {

// Minimal Stage that exposes an externally-owned queue as its output; used to
// feed the generator in tests.
template class PassthroughStage : public Stage {
 public:
  explicit PassthroughStage(Queue* queue) : queue_(queue) {}
  void Start() override {}
  void Stop() override {}
  StageMetricProto FlushMetrics() override { return StageMetricProto(); }
  QueueBase* GetOutput(std::string_view name = "") override {
    (void)name;  // Any name maps to the single queue.
    return queue_;
  }
  void SetInputs(absl::Span inputs) override {
    if (!inputs.empty()) {
      throw std::runtime_error("PassthroughStage expects no inputs");
    }
  }

 private:
  Queue* queue_;
};

}  // namespace

// Fixture providing a fresh input queue and a default config
// (batch_size=4, threads=1, output capacity=10).
class TensorGeneratorTest : public ::testing::Test {
 protected:
  void SetUp() override {
    input_queue_ = std::make_unique>(100);
    config_.set_batch_size(4);
    config_.set_threads(1);
    config_.mutable_output()->set_queue_capacity(10);
  }

  // Builds a deterministic V6 frame with recognizable values in every field
  // the generator reads.
  FrameType CreateTestFrame() {
    FrameType frame{};
    std::memset(&frame, 0, sizeof(frame));
    frame.version = 6;
    frame.input_format = 3;
    // Fill probabilities with test values.
    for (ssize_t i = 0; i < 1858; ++i) {
      frame.probabilities[i] = static_cast(i) / 1858.0f;
    }
    // Fill planes with test pattern.
    for (ssize_t i = 0; i < 104; ++i) {
      frame.planes[i] = 0x0F0F0F0F0F0F0F0FULL + i;  // Test pattern
    }
    // Set castling rights.
    frame.castling_us_ooo = 1;
    frame.castling_us_oo = 0;
    frame.castling_them_ooo = 1;
    frame.castling_them_oo = 1;
    // Set other fields.
    frame.side_to_move_or_enpassant = 1;
    frame.rule50_count = 50;
    // Set Q and D values.
    frame.result_q = 0.5f;
    frame.result_d = 0.2f;
    frame.best_q = 0.3f;
    frame.best_d = 0.1f;
    frame.best_m = 42.5f;
    frame.plies_left = 42.5f;
    return frame;
  }

  // Checks only shapes/types of the three tensors produced for `frames`.
  void VerifyTensorTuple(const TensorTuple& tensors,
                         const std::vector& frames) {
    const size_t batch_size = frames.size();
    // Verify tuple has 3 elements
    ASSERT_EQ(tensors.size(), 3);
    // Verify input tensor: (batch_size, 112, 8, 8)
    const auto* planes_tensor = dynamic_cast*>(tensors[0].get());
    ASSERT_NE(planes_tensor, nullptr);
    EXPECT_EQ(planes_tensor->shape().size(), 4);
    EXPECT_EQ(planes_tensor->shape()[0], batch_size);
    EXPECT_EQ(planes_tensor->shape()[1], 112);
    EXPECT_EQ(planes_tensor->shape()[2], 8);
    EXPECT_EQ(planes_tensor->shape()[3], 8);
    // Verify probabilities tensor: (batch_size, 1858)
    const auto* probs_tensor = dynamic_cast*>(tensors[1].get());
    ASSERT_NE(probs_tensor, nullptr);
    EXPECT_EQ(probs_tensor->shape().size(), 2);
    EXPECT_EQ(probs_tensor->shape()[0], batch_size);
    EXPECT_EQ(probs_tensor->shape()[1], 1858);
    // Verify values tensor: (batch_size, 6, 3)
    const auto* values_tensor = dynamic_cast*>(tensors[2].get());
    ASSERT_NE(values_tensor, nullptr);
    EXPECT_EQ(values_tensor->shape().size(), 3);
    EXPECT_EQ(values_tensor->shape()[0], batch_size);
    EXPECT_EQ(values_tensor->shape()[1], 6);
    EXPECT_EQ(values_tensor->shape()[2], 3);
  }

  // Checks tensor contents against the source frames (probabilities, value
  // heads, bitboard expansion and metadata planes).
  void VerifyTensorData(const TensorTuple& tensors,
                        const std::vector& frames) {
    const size_t batch_size = frames.size();
    const auto* planes_tensor = dynamic_cast*>(tensors[0].get());
    const auto* probs_tensor = dynamic_cast*>(tensors[1].get());
    const auto* values_tensor = dynamic_cast*>(tensors[2].get());
    for (size_t i = 0; i < batch_size; ++i) {
      const auto& frame = frames[i];
      // Verify probabilities data.
      auto probs_slice = probs_tensor->slice({static_cast(i)});
      for (ssize_t j = 0; j < 1858; ++j) {
        EXPECT_FLOAT_EQ(probs_slice[j], frame.probabilities[j]);
      }
      // Verify values tensor [batch, 6, 3] with raw q/d/m values
      // Index 0: result (q=0.5, d=0.2, m=42.5)
      auto values_slice = values_tensor->slice({static_cast(i)});
      EXPECT_FLOAT_EQ(values_slice[0 * 3 + 0], 0.5f);   // result_q
      EXPECT_FLOAT_EQ(values_slice[0 * 3 + 1], 0.2f);   // result_d
      EXPECT_FLOAT_EQ(values_slice[0 * 3 + 2], 42.5f);  // result_m
      // Index 1: best (q=0.3, d=0.1, m=42.5)
      EXPECT_FLOAT_EQ(values_slice[1 * 3 + 0], 0.3f);   // best_q
      EXPECT_FLOAT_EQ(values_slice[1 * 3 + 1], 0.1f);   // best_d
      EXPECT_FLOAT_EQ(values_slice[1 * 3 + 2], 42.5f);  // best_m
      // Verify planes data - check first few planes and meta planes.
      auto planes_slice = planes_tensor->slice({static_cast(i)});
      // Check first plane (plane 0).
      uint64_t expected_plane_0 = 0x0F0F0F0F0F0F0F0FULL;
      for (ssize_t square = 0; square < 64; ++square) {
        float expected =
            static_cast((expected_plane_0 >> (63 - square)) & 1);
        EXPECT_FLOAT_EQ(planes_slice[square], expected);
      }
      // Check meta planes.
      // Plane 104: castling_us_ooo = 1
      for (ssize_t square = 104 * 64; square < 105 * 64; ++square) {
        EXPECT_FLOAT_EQ(planes_slice[square], 1.0f);
      }
      // Plane 105: castling_us_oo = 0
      for (ssize_t square = 105 * 64; square < 106 * 64; ++square) {
        EXPECT_FLOAT_EQ(planes_slice[square], 0.0f);
      }
      // Plane 109: rule50_count = 50, should be 50/99
      for (ssize_t square = 109 * 64; square < 110 * 64; ++square) {
        EXPECT_FLOAT_EQ(planes_slice[square], 50.0f / 99.0f);
      }
      // Plane 110: all zeros
      for (ssize_t square = 110 * 64; square < 111 * 64; ++square) {
        EXPECT_FLOAT_EQ(planes_slice[square], 0.0f);
      }
      // Plane 111: all ones
      for (ssize_t square = 111 * 64; square < 112 * 64; ++square) {
        EXPECT_FLOAT_EQ(planes_slice[square], 1.0f);
      }
    }
  }

  std::unique_ptr> input_queue_;
  TensorGeneratorConfig config_;
};

TEST_F(TensorGeneratorTest, GeneratesCorrectTensorShapes) {
  TensorGenerator generator(config_);
  generator.SetInputs({input_queue_.get()});
  generator.Start();
  auto producer = input_queue_->CreateProducer();
  std::vector frames;
  for (size_t i = 0; i < config_.batch_size(); ++i) {
    frames.push_back(CreateTestFrame());
    producer.Put(frames.back());
  }
  producer.Close();
  auto tensors = generator.output_queue()->Get();
  VerifyTensorTuple(tensors, frames);
}

TEST_F(TensorGeneratorTest, GeneratesCorrectTensorData) {
  TensorGenerator generator(config_);
  generator.SetInputs({input_queue_.get()});
  generator.Start();
  auto producer = input_queue_->CreateProducer();
  std::vector frames;
  for (size_t i = 0; i < config_.batch_size(); ++i) {
    frames.push_back(CreateTestFrame());
    producer.Put(frames.back());
  }
  producer.Close();
  auto tensors = generator.output_queue()->Get();
  VerifyTensorTuple(tensors, frames);
  VerifyTensorData(tensors, frames);
}

TEST_F(TensorGeneratorTest, HandlesMultipleBatches) {
  TensorGenerator generator(config_);
  generator.SetInputs({input_queue_.get()});
  generator.Start();
  auto producer = input_queue_->CreateProducer();
  // Send two full batches.
  std::vector all_frames;
  for (ssize_t batch = 0; batch < 2; ++batch) {
    for (size_t i = 0; i < config_.batch_size(); ++i) {
      auto frame = CreateTestFrame();
      frame.version = batch * 1000 + i;  // Unique version for each frame
      all_frames.push_back(frame);
      producer.Put(frame);
    }
  }
  producer.Close();
  // Get first batch.
  auto tensors1 = generator.output_queue()->Get();
  std::vector batch1_frames(
      all_frames.begin(), all_frames.begin() + config_.batch_size());
  VerifyTensorTuple(tensors1, batch1_frames);
  // Get second batch.
  auto tensors2 = generator.output_queue()->Get();
  std::vector batch2_frames(
      all_frames.begin() + config_.batch_size(), all_frames.end());
  VerifyTensorTuple(tensors2, batch2_frames);
  // No more batches should be available.
  EXPECT_THROW(generator.output_queue()->Get(), QueueClosedException);
}

TEST_F(TensorGeneratorTest, HandlesDifferentBatchSizes) {
  config_.set_batch_size(2);
  TensorGenerator generator(config_);
  generator.SetInputs({input_queue_.get()});
  generator.Start();
  auto producer = input_queue_->CreateProducer();
  std::vector frames;
  for (size_t i = 0; i < config_.batch_size(); ++i) {
    frames.push_back(CreateTestFrame());
    producer.Put(frames.back());
  }
  producer.Close();
  auto tensors = generator.output_queue()->Get();
  VerifyTensorTuple(tensors, frames);
}

TEST_F(TensorGeneratorTest, HandlesEmptyInput) {
  TensorGenerator generator(config_);
  generator.SetInputs({input_queue_.get()});
  generator.Start();
  // Close input queue without sending data.
  input_queue_->Close();
  // Should not output any tensors.
  EXPECT_THROW(generator.output_queue()->Get(), QueueClosedException);
}

TEST_F(TensorGeneratorTest, VerifiesPlanesConversion) {
  config_.set_batch_size(1);
  TensorGenerator generator(config_);
  generator.SetInputs({input_queue_.get()});
  generator.Start();
  auto producer = input_queue_->CreateProducer();
  FrameType frame = CreateTestFrame();
  // Set specific bit pattern for plane 0.
  frame.planes[0] = 0xAAAAAAAAAAAAAAAAULL;  // Alternating bits
  // Set specific values for meta planes.
  frame.castling_us_ooo = 1;
  frame.castling_us_oo = 0;
  frame.rule50_count = 75;
  producer.Put(frame);
  producer.Close();
  auto tensors = generator.output_queue()->Get();
  const auto* planes_tensor = dynamic_cast*>(tensors[0].get());
  auto planes_slice = planes_tensor->slice({0});
  // Verify plane 0 bit conversion.
  for (ssize_t square = 0; square < 64; ++square) {
    float expected =
        static_cast((0xAAAAAAAAAAAAAAAAULL >> (63 - square)) & 1);
    EXPECT_FLOAT_EQ(planes_slice[square], expected)
        << "Mismatch at square " << square;
  }
  // Verify rule50_count conversion: 75/99.
  for (ssize_t square = 109 * 64; square < 110 * 64; ++square) {
    EXPECT_FLOAT_EQ(planes_slice[square], 75.0f / 99.0f);
  }
}

TEST_F(TensorGeneratorTest, VerifiesQDConversion) {
  config_.set_batch_size(1);
  TensorGenerator generator(config_);
  generator.SetInputs({input_queue_.get()});
  generator.Start();
  auto producer = input_queue_->CreateProducer();
  FrameType frame = CreateTestFrame();
  // Test specific Q/D values.
  frame.result_q = 0.4f;
  frame.result_d = 0.3f;
  frame.best_q = -0.2f;
  frame.best_d = 0.1f;
  producer.Put(frame);
  producer.Close();
  auto tensors = generator.output_queue()->Get();
  const auto* values_tensor = dynamic_cast*>(tensors[2].get());
  auto values_slice = values_tensor->slice({0});
  // Verify result values: q=0.4, d=0.3 (raw values, no WDL conversion)
  EXPECT_FLOAT_EQ(values_slice[0 * 3 + 0], 0.4f);   // result_q
  EXPECT_FLOAT_EQ(values_slice[0 * 3 + 1], 0.3f);   // result_d
  // Verify best values: q=-0.2, d=0.1 (raw values, no WDL conversion)
  EXPECT_FLOAT_EQ(values_slice[1 * 3 + 0], -0.2f);  // best_q
  EXPECT_FLOAT_EQ(values_slice[1 * 3 + 1], 0.1f);   // best_d
}

}  // namespace training
}  // namespace lczero



================================================
FILE: csrc/loader/stages/training_chunk.h
================================================
#pragma once

#include
#include
#include
#include

#include "loader/frame_type.h"

namespace lczero {
namespace training {

// A decoded training chunk plus bookkeeping used by the shuffling pool.
struct TrainingChunk {
  std::vector frames;           // Decoded frames of this chunk.
  std::string sort_key;         // Grouping key (e.g. source file).
  size_t index_within_sort_key = 0;  // Position among chunks with same key.
  size_t global_index = 0;           // Position in the overall stream.
  uint32_t use_count = 0;            // Times this chunk has been sampled.
};

// Request for (re)filling cached items starting at a global index.
struct CacheRequest {
  size_t global_index = 0;
  uint16_t next_use = 0;
  std::vector items;
};

}  // namespace training
}  // namespace lczero



================================================
FILE: csrc/tools/dump_chunk_main.cc
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "trainingdata/trainingdata_v6.h"
#include "utils/training_data_printer.h"

ABSL_FLAG(std::string, chunk_path, "",
          "Path to the chunk file (.gz) to dump.");
ABSL_FLAG(int64_t, max_entries, -1,
          "Maximum number of entries to print.
-1 prints all entries.");
ABSL_FLAG(int64_t, float_values_per_line, 8,
          "Number of floating point values per output line.");
ABSL_FLAG(int64_t, plane_values_per_line, 4,
          "Number of plane values per output line.");

namespace lczero {
namespace training {
namespace {

using ::lczero::training::FrameType;
using ::lczero::training::PrintTrainingDataEntry;

// Reads fixed-size FrameType records from a gzipped chunk file and prints
// each one; stops at EOF, after max_entries, or LOG(FATAL)s on read errors.
void DumpChunk(const std::string& path, int64_t max_entries,
               int64_t float_per_line, int64_t plane_per_line) {
  gzFile file = gzopen(path.c_str(), "rb");
  if (file == nullptr) {
    LOG(FATAL) << "Failed to open chunk file: " << path;
  }
  size_t index = 0;
  while (true) {
    FrameType entry;
    const int bytes_read = gzread(file, &entry, sizeof(entry));
    if (bytes_read == 0) {
      break;  // Clean EOF.
    }
    if (bytes_read < 0) {
      int errnum = 0;
      const char* error_message = gzerror(file, &errnum);
      gzclose(file);
      LOG(FATAL) << "Error while reading chunk: " << error_message;
    }
    // A short read means a truncated trailing record.
    if (bytes_read != sizeof(entry)) {
      gzclose(file);
      LOG(FATAL) << "Unexpected chunk size. Expected " << sizeof(entry)
                 << " bytes, got " << bytes_read << ".";
    }
    const std::string header = absl::StrFormat("Entry %zu:", index);
    PrintTrainingDataEntry(entry, header, float_per_line, plane_per_line);
    ++index;
    if (max_entries >= 0 && static_cast(index) >= max_entries) {
      break;
    }
  }
  gzclose(file);
  LOG(INFO) << "Printed " << index << " entries.";
}

}  // namespace
}  // namespace training
}  // namespace lczero

int main(int argc, char** argv) {
  absl::ParseCommandLine(argc, argv);
  absl::InitializeLog();
  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
  const std::string chunk_path = absl::GetFlag(FLAGS_chunk_path);
  if (chunk_path.empty()) {
    LOG(FATAL) << "--chunk_path flag is required.";
  }
  const int64_t max_entries = absl::GetFlag(FLAGS_max_entries);
  const int64_t float_per_line = absl::GetFlag(FLAGS_float_values_per_line);
  const int64_t plane_per_line = absl::GetFlag(FLAGS_plane_values_per_line);
  lczero::training::DumpChunk(chunk_path, max_entries, float_per_line,
                              plane_per_line);
  return 0;
}
================================================
FILE: csrc/tools/filter_chunks_main.cc
================================================
// NOTE(review): the <...> include names and template arguments in this file
// were stripped by the text extraction — verify against the repo.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "loader/chunk_source/chunk_source.h"
#include "loader/chunk_source/tar_chunk_source.h"
#include "trainingdata/trainingdata_v6.h"

ABSL_FLAG(std::string, input_dir, ".", "Directory to scan for .tar files.");
ABSL_FLAG(std::string, output_dir, ".",
          "Directory where matching chunks will be written.");
ABSL_FLAG(std::string, plane_values, "",
          "Comma separated list of plane values (decimal or hex).");

namespace {

namespace fs = std::filesystem;

using ::lczero::training::ChunkSource;
using ::lczero::training::ChunkSourceLoaderConfig;
using ::lczero::training::FrameType;
using ::lczero::training::TarChunkSource;

// Returns all regular *.tar files in `directory`, sorted by filename.
std::vector CollectTarFiles(const fs::path& directory) {
  std::vector files;
  for (const auto& entry : fs::directory_iterator(directory)) {
    const fs::path& path = entry.path();
    if (entry.is_regular_file() && path.extension() == ".tar") {
      files.push_back(path);
    }
  }
  absl::c_sort(files, [](const fs::path& lhs, const fs::path& rhs) {
    return lhs.filename() < rhs.filename();
  });
  return files;
}

// Parses a comma-separated list of decimal or 0x-prefixed hex values;
// LOG(FATAL)s on empty input or malformed tokens.
std::vector ParsePlaneValues(absl::string_view value_list) {
  std::vector result;
  if (value_list.empty()) {
    LOG(FATAL) << "--plane_values flag must not be empty.";
  }
  for (absl::string_view token :
       absl::StrSplit(value_list, ',', absl::SkipWhitespace())) {
    token = absl::StripAsciiWhitespace(token);
    if (token.empty()) continue;
    uint64_t value = 0;
    if (absl::StartsWithIgnoreCase(token, "0x")) {
      const absl::string_view hex_part = token.substr(2);
      if (hex_part.empty() || !absl::SimpleHexAtoi(hex_part, &value)) {
        LOG(FATAL) << "Invalid hex plane value: " << token;
      }
    } else if (!absl::SimpleAtoi(token, &value)) {
      LOG(FATAL) << "Invalid decimal plane value: " << token;
    }
    result.push_back(value);
  }
  if (result.empty()) {
    LOG(FATAL) << "No plane values were parsed.";
  }
  return result;
}

// True when the frame's leading planes byte-wise equal `expected`.
bool PlanesMatch(const FrameType& entry, absl::Span expected) {
  if (expected.size() > std::size(entry.planes)) return false;
  const size_t bytes = expected.size() * sizeof(uint64_t);
  return std::memcmp(entry.planes, expected.data(), bytes) == 0;
}

// Returns the index of the first frame in `chunk` whose planes match.
std::optional FindMatchingFrameIndex(
    const std::vector& chunk, absl::Span expected) {
  for (size_t frame = 0; frame < chunk.size(); ++frame) {
    if (PlanesMatch(chunk[frame], expected)) return frame;
  }
  return std::nullopt;
}

// Writes all frames of `chunk` into <output_dir>/<base>_<index>_<frame>.gz,
// looping because gzwrite takes at most UINT_MAX bytes per call.
void WriteChunk(const fs::path& output_dir, absl::string_view base_name,
                size_t index, size_t frame_index,
                const std::vector& chunk) {
  fs::create_directories(output_dir);
  const fs::path output_path =
      output_dir / absl::StrCat(base_name, "_", index, "_", frame_index, ".gz");
  gzFile file = gzopen(output_path.string().c_str(), "wb");
  if (file == nullptr) {
    LOG(FATAL) << "Failed to open output file: " << output_path.string();
  }
  size_t remaining = chunk.size() * sizeof(FrameType);
  const char* data = reinterpret_cast(chunk.data());
  while (remaining > 0) {
    const unsigned int to_write = static_cast(
        std::min(remaining, std::numeric_limits::max()));
    const int written = gzwrite(file, data, to_write);
    if (written == 0) {
      int errnum = 0;
      const char* error_message = gzerror(file, &errnum);
      gzclose(file);
      LOG(FATAL) << "Failed to write chunk: " << error_message;
    }
    data += written;
    remaining -= static_cast(written);
  }
  if (gzclose(file) != Z_OK) {
    LOG(FATAL) << "Failed to close output file: " << output_path.string();
  }
  LOG(INFO) << "Wrote matching chunk to " << output_path.string();
}

// Scans every chunk of one tar file and writes those containing a frame whose
// planes match `expected_planes`. Unreadable chunks are skipped with a
// warning (best-effort).
void ProcessTar(const fs::path& tar_path, const fs::path& output_dir,
                absl::Span expected_planes) {
  std::unique_ptr source = std::make_unique(
      tar_path, ChunkSourceLoaderConfig::V6TrainingData);
  const std::string base_name = tar_path.stem().string();
  size_t written_count = 0;
  for (size_t index = 0, total = source->GetChunkCount(); index < total;
       ++index) {
    const std::optional> chunk = source->GetChunkData(index);
    if (!chunk) {
      LOG(WARNING) << "Skipping unreadable chunk " << index << " in "
                   << tar_path.string();
      continue;
    }
    const std::optional match =
        FindMatchingFrameIndex(*chunk, expected_planes);
    if (!match) continue;
    WriteChunk(output_dir, base_name, index, *match, *chunk);
    ++written_count;
  }
  LOG(INFO) << "Finished processing " << tar_path.string() << ": wrote "
            << written_count << " chunk(s).";
}

}  // namespace

int main(int argc, char** argv) {
  absl::InitializeLog();
  absl::ParseCommandLine(argc, argv);
  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
  const fs::path input_dir(absl::GetFlag(FLAGS_input_dir));
  const fs::path output_dir(absl::GetFlag(FLAGS_output_dir));
  const std::string plane_values = absl::GetFlag(FLAGS_plane_values);
  if (!fs::exists(input_dir) || !fs::is_directory(input_dir)) {
    LOG(FATAL) << "Input directory does not exist: " << input_dir.string();
  }
  const std::vector expected_planes = ParsePlaneValues(plane_values);
  fs::create_directories(output_dir);
  const std::vector tar_files = CollectTarFiles(input_dir);
  const absl::Span expected_span(expected_planes);
  // One worker thread per tar file; failures are logged, not fatal.
  std::vector workers;
  workers.reserve(tar_files.size());
  for (const auto& tar_path : tar_files) {
    workers.emplace_back([tar_path, output_dir, expected_span]() {
      LOG(INFO) << "Processing tar file: " << tar_path.string();
      try {
        ProcessTar(tar_path, output_dir, expected_span);
      } catch (const std::exception& e) {
        LOG(WARNING) << "Failed to process tar file " << tar_path << ": "
                     << e.what();
      }
    });
  }
  for (auto& worker : workers) {
    worker.join();
  }
  return 0;
}



================================================
FILE: csrc/tools/position_weight_stats_main.cc
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "loader/chunk_source/chunk_source.h"
#include "loader/chunk_source/tar_chunk_source.h"
#include "loader/stages/position_sampling.h"
#include "proto/data_loader_config.pb.h"
#include "trainingdata/trainingdata_v6.h"
#include "utils/training_data_printer.h"

ABSL_FLAG(std::string, input_dir, "", "Directory to scan for .tar files.");
ABSL_FLAG(float, q_weight, 6.0, "Value for diff_focus_q_weight.");
ABSL_FLAG(float, pol_scale, 3.5, "Value for diff_focus_pol_scale.");

namespace {
namespace fs = std::filesystem;

using ::lczero::training::ChunkSource;
using ::lczero::training::ChunkSourceLoaderConfig;
using ::lczero::training::ComputePositionSamplingWeight;
using ::lczero::training::FrameType;
using ::lczero::training::PositionSamplingConfig;
using ::lczero::training::PrintTrainingDataEntry;
using ::lczero::training::TarChunkSource;

// A training frame together with its computed sampling weight.
struct WeightedPosition {
  FrameType data;
  float weight;
};

// Returns all regular ".tar" files directly inside `directory`, sorted by
// filename.
std::vector<fs::path> CollectTarFiles(const fs::path& directory) {
  std::vector<fs::path> files;
  for (const auto& entry : fs::directory_iterator(directory)) {
    const fs::path& path = entry.path();
    if (entry.is_regular_file() && path.extension() == ".tar") {
      files.push_back(path);
    }
  }
  absl::c_sort(files, [](const fs::path& lhs, const fs::path& rhs) {
    return lhs.filename() < rhs.filename();
  });
  return files;
}

// Computes the sampling weight of every position in `tar_path`. Also tracks
// the single highest-weighted position in `*max_weighted` (if non-null).
std::vector<float> CollectWeights(const fs::path& tar_path,
                                  const PositionSamplingConfig& config,
                                  WeightedPosition* max_weighted) {
  std::vector<float> weights;
  std::unique_ptr<ChunkSource> source = std::make_unique<TarChunkSource>(
      tar_path, ChunkSourceLoaderConfig::V6TrainingData);
  const size_t total = source->GetChunkCount();
  for (size_t index = 0; index < total; ++index) {
    if (index % 1000 == 0) {
      LOG(INFO) << absl::StreamFormat(" Progress: %zu/%zu chunks (%.1f%%)",
                                      index, total, 100.0 * index / total);
    }
    const std::optional<std::vector<FrameType>> chunk =
        source->GetChunkData(index);
    if (!chunk) {
      LOG(WARNING) << "Skipping unreadable chunk " << index << " in "
                   << tar_path.string();
      continue;
    }
    if (chunk->empty()) continue;
    for (const auto& entry : *chunk) {
      const float weight = ComputePositionSamplingWeight(entry, config);
      weights.push_back(weight);
      if (max_weighted && weight > max_weighted->weight) {
        max_weighted->data = entry;
        max_weighted->weight = weight;
      }
    }
  }
  return weights;
}

// Prints a fixed-bucket ASCII histogram of `sorted_weights` (must be sorted
// ascending; front/back are used as min/max).
void PrintHistogram(const std::vector<float>& sorted_weights) {
  if (sorted_weights.empty()) return;
  constexpr int kBuckets = 50;
  constexpr int kMaxWidth = 60;
  const float min_val = sorted_weights.front();
  const float max_val = sorted_weights.back();
  const float range = max_val - min_val;
  if (range == 0.0f) {
    LOG(INFO) << "\nHistogram: All weights are identical (" << min_val << ")";
    return;
  }
  std::vector<int> buckets(kBuckets, 0);
  for (float weight : sorted_weights) {
    int bucket = static_cast<int>((weight - min_val) / range * (kBuckets - 1));
    bucket = std::clamp(bucket, 0, kBuckets - 1);
    ++buckets[bucket];
  }
  const int max_count = *absl::c_max_element(buckets);
  if (max_count == 0) return;
  LOG(INFO) << "\nHistogram:";
  for (int bucket = 0; bucket < kBuckets; ++bucket) {
    if (buckets[bucket] == 0) continue;
    const float bucket_start = min_val + range * bucket / kBuckets;
    const float bucket_end = min_val + range * (bucket + 1) / kBuckets;
    // Round-to-nearest scaling of the bar width.
    const int width = (buckets[bucket] * kMaxWidth + max_count / 2) / max_count;
    std::string bar;
    for (int i = 0; i < width; ++i) bar += "█";
    LOG(INFO) << absl::StreamFormat("[%.4f-%.4f) │%s (%d)", bucket_start,
                                    bucket_end, bar, buckets[bucket]);
  }
}

// Prints the 0th..100th percentile of `sorted_weights` (must be sorted).
void PrintPercentiles(const std::vector<float>& sorted_weights) {
  if (sorted_weights.empty()) return;
  LOG(INFO) << "\nPercentiles:";
  for (int p = 0; p <= 100; ++p) {
    const size_t idx = (sorted_weights.size() - 1) * p / 100;
    LOG(INFO) << absl::StreamFormat(" %3d%%: %.6f", p, sorted_weights[idx]);
  }
}

// Prints count and mean of the collected weights.
void PrintStatistics(const std::vector<float>& weights) {
  if (weights.empty()) {
    LOG(INFO) << "No weights collected.";
    return;
  }
  const double sum = std::accumulate(weights.begin(), weights.end(), 0.0);
  const double mean = sum / weights.size();
  LOG(INFO) << "\nStatistics:";
  LOG(INFO) << " Total positions: " << weights.size();
  LOG(INFO) << absl::StreamFormat(" Mean: %.6f", mean);
}

}  // namespace

int main(int argc, char** argv) {
  absl::InitializeLog();
  absl::ParseCommandLine(argc, argv);
  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
  const std::string input_dir_str = absl::GetFlag(FLAGS_input_dir);
  const float q_weight = absl::GetFlag(FLAGS_q_weight);
  const float pol_scale = absl::GetFlag(FLAGS_pol_scale);
  if (input_dir_str.empty()) {
    LOG(FATAL) << "--input_dir must be specified.";
  }
  const fs::path input_dir(input_dir_str);
  if (!fs::exists(input_dir) || !fs::is_directory(input_dir)) {
    LOG(FATAL) << "Input directory does not exist: " << input_dir.string();
  }
  PositionSamplingConfig config;
  config.set_diff_focus_q_weight(q_weight);
  config.set_diff_focus_pol_scale(pol_scale);
  const std::vector<fs::path> tar_files = CollectTarFiles(input_dir);
  LOG(INFO) << "Found " << tar_files.size() << " tar file(s).";
  WeightedPosition max_weighted = {{}, 0.0f};
  std::vector<float> all_weights;
  for (const auto& tar_path : tar_files) {
    LOG(INFO) << "Processing: " << tar_path.string();
    std::vector<float> weights =
        CollectWeights(tar_path, config, &max_weighted);
    all_weights.insert(all_weights.end(), weights.begin(), weights.end());
    LOG(INFO) << " Collected " << weights.size() << " position(s).";
  }
  absl::c_sort(all_weights);
  PrintStatistics(all_weights);
  PrintPercentiles(all_weights);
  PrintHistogram(all_weights);
  if (max_weighted.weight > 0.0f) {
    const std::string header = absl::StrFormat(
        "\nPosition with highest weight (%.6f):", max_weighted.weight);
    PrintTrainingDataEntry(max_weighted.data, header, 8, 4);
  }
  return 0;
}

================================================
FILE: csrc/tools/rescore_chunk_main.cc
================================================

// NOTE(review): extraction stripped bracketed include names; this list is
// reconstructed from usage below — verify against the original file.
#include <filesystem>
#include <string>
#include <utility>
#include <vector>

#include <absl/flags/flag.h>
#include <absl/flags/parse.h>
#include <absl/log/globals.h>
#include <absl/log/initialize.h>
#include <absl/log/log.h>

#include "chess/board.h"
#include "proto/data_loader_config.pb.h"
#include "syzygy/syzygy.h"
#include "trainingdata/reader.h"
#include "trainingdata/rescorer.h"
#include
"trainingdata/trainingdata_v6.h" #include "trainingdata/writer.h" #include "utils/exception.h" ABSL_FLAG(std::string, chunk_path, "", "Path to the chunk file (.gz) that should be rescored."); ABSL_FLAG(std::string, syzygy_paths, "", "Comma-separated list of Syzygy tablebase directories."); ABSL_FLAG(double, dist_temp, 1.0, "Policy temperature applied during rescoring."); ABSL_FLAG(double, dist_offset, 0.0, "Policy offset applied during rescoring."); ABSL_FLAG(double, dtz_boost, 0.0, "DTZ boost applied during policy adjustments."); ABSL_FLAG(int, new_input_format, -1, "Optional conversion target for input format (-1 keeps original)."); ABSL_FLAG( double, deblunder_threshold, -1.0, "Threshold for policy deblundering adjustments (negative to disable)."); ABSL_FLAG( double, deblunder_width, -1.0, "Width controlling smoothing around threshold (negative to disable)."); namespace { namespace fs = std::filesystem; using ::lczero::training::ChunkRescorerConfig; std::vector ReadChunkFrames(const fs::path& path) { std::vector frames; lczero::TrainingDataReader reader(path.string()); lczero::V6TrainingData frame; while (reader.ReadChunk(&frame)) { frames.push_back(frame); } return frames; } void WriteChunkFrames(const fs::path& path, const std::vector& frames) { lczero::TrainingDataWriter writer(path.string()); for (const auto& frame : frames) { writer.WriteChunk(frame); } writer.Finalize(); } fs::path BuildOutputPath(const fs::path& input_path) { fs::path directory = input_path.parent_path(); fs::path stem = input_path.stem(); fs::path filename = stem; filename += "_rescored.gz"; return directory / filename; } } // namespace int main(int argc, char** argv) { absl::InitializeLog(); absl::ParseCommandLine(argc, argv); absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo); const std::string chunk_path_flag = absl::GetFlag(FLAGS_chunk_path); if (chunk_path_flag.empty()) { LOG(FATAL) << "--chunk_path flag is required."; } const fs::path chunk_path(chunk_path_flag); 
ChunkRescorerConfig config; const std::string syzygy_paths_flag = absl::GetFlag(FLAGS_syzygy_paths); if (!syzygy_paths_flag.empty()) { config.set_syzygy_paths(syzygy_paths_flag); } config.set_dist_temp(static_cast(absl::GetFlag(FLAGS_dist_temp))); config.set_dist_offset(static_cast(absl::GetFlag(FLAGS_dist_offset))); config.set_dtz_boost(static_cast(absl::GetFlag(FLAGS_dtz_boost))); config.set_new_input_format( static_cast(absl::GetFlag(FLAGS_new_input_format))); const double deblunder_threshold_flag = absl::GetFlag(FLAGS_deblunder_threshold); const double deblunder_width_flag = absl::GetFlag(FLAGS_deblunder_width); if (deblunder_threshold_flag >= 0.0 && deblunder_width_flag >= 0.0) { config.set_deblunder_threshold( static_cast(deblunder_threshold_flag)); config.set_deblunder_width(static_cast(deblunder_width_flag)); } else if (deblunder_threshold_flag >= 0.0 || deblunder_width_flag >= 0.0) { LOG(FATAL) << "Both --deblunder_threshold and --deblunder_width must be " << "set to non-negative values together."; } LOG(INFO) << "Reading chunk from " << chunk_path.string(); std::vector frames; try { frames = ReadChunkFrames(chunk_path); } catch (const lczero::Exception& exception) { LOG(FATAL) << "Failed to read chunk: " << exception.what(); } LOG(INFO) << "Loaded " << frames.size() << " frame(s) from chunk."; if (frames.empty()) { LOG(WARNING) << "Chunk contains no frames; writing empty output."; try { WriteChunkFrames(BuildOutputPath(chunk_path), frames); } catch (const lczero::Exception& exception) { LOG(FATAL) << "Failed to write rescored chunk: " << exception.what(); } return 0; } lczero::InitializeMagicBitboards(); if (config.has_deblunder_threshold() && config.has_deblunder_width()) { lczero::RescorerDeblunderSetup(config.deblunder_threshold(), config.deblunder_width()); } lczero::SyzygyTablebase tablebase; if (!config.syzygy_paths().empty()) { LOG(INFO) << "Initializing Syzygy tablebases from '" << config.syzygy_paths() << "'."; const std::string 
syzygy_paths(config.syzygy_paths()); if (!tablebase.init(syzygy_paths)) { LOG(WARNING) << "Failed to initialize Syzygy tablebases."; } } LOG(INFO) << "Rescoring chunk with dist_temp=" << config.dist_temp() << ", dist_offset=" << config.dist_offset() << ", dtz_boost=" << config.dtz_boost() << ", new_input_format=" << config.new_input_format() << "."; try { frames = lczero::RescoreTrainingData( std::move(frames), &tablebase, config.dist_temp(), config.dist_offset(), config.dtz_boost(), config.new_input_format()); } catch (const lczero::Exception& exception) { LOG(FATAL) << "Failed to rescore chunk: " << exception.what(); } const fs::path output_path = BuildOutputPath(chunk_path); LOG(INFO) << "Writing rescored chunk to " << output_path.string(); try { WriteChunkFrames(output_path, frames); } catch (const lczero::Exception& exception) { LOG(FATAL) << "Failed to write rescored chunk: " << exception.what(); } LOG(INFO) << "Completed rescoring of chunk."; return 0; } ================================================ FILE: csrc/tools/result_distribution_main.cc ================================================ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "loader/chunk_source/tar_chunk_source.h" #include "trainingdata/trainingdata_v6.h" ABSL_FLAG(std::string, output_csv, "", "Destination CSV file. Writes to stdout if empty."); ABSL_FLAG(int, num_threads, 0, "Number of worker threads. 
Defaults to hardware concurrency."); namespace { namespace fs = std::filesystem; using ::lczero::training::ChunkSourceLoaderConfig; using ::lczero::training::FrameType; using ::lczero::training::TarChunkSource; enum class ChunkResult { kWin, kDraw, kLoss }; struct ResultCounts { uint64_t wins = 0; uint64_t draws = 0; uint64_t losses = 0; }; class CsvWriter { public: CsvWriter(std::ostream* output, absl::Mutex* mutex) : output_(output), mutex_(mutex) {} void Write(absl::string_view basename, const ResultCounts& counts) const { absl::MutexLock lock(mutex_); *output_ << basename << ',' << counts.wins << ',' << counts.draws << ',' << counts.losses << '\n'; output_->flush(); } private: std::ostream* output_; absl::Mutex* mutex_; }; std::ostream& SelectOutput(const fs::path& output_path, std::ofstream& file_stream) { if (output_path.empty()) return std::cout; file_stream.open(output_path, std::ios::out | std::ios::trunc); if (!file_stream) { LOG(FATAL) << "Failed to open output file: " << output_path.string(); } return file_stream; } constexpr float kFloatTolerance = 1e-6f; std::optional DetermineChunkResult(absl::string_view chunk_payload, size_t chunk_index, const fs::path& tar_path) { if (chunk_payload.size() < sizeof(FrameType)) { LOG(WARNING) << "Chunk " << chunk_index << " in " << tar_path.string() << " is too small."; return std::nullopt; } FrameType frame; std::memcpy(&frame, chunk_payload.data(), sizeof(frame)); if (std::fabs(frame.result_d - 1.0f) <= kFloatTolerance) { return ChunkResult::kDraw; } if (std::fabs(frame.result_d) > kFloatTolerance) { LOG(WARNING) << "Chunk " << chunk_index << " in " << tar_path.string() << " has unexpected result_d=" << frame.result_d << '.'; return std::nullopt; } const bool side_to_move = frame.side_to_move_or_enpassant != 0; if (std::fabs(frame.result_q - 1.0f) <= kFloatTolerance) { return side_to_move ? ChunkResult::kLoss : ChunkResult::kWin; } if (std::fabs(frame.result_q + 1.0f) <= kFloatTolerance) { return side_to_move ? 
ChunkResult::kWin : ChunkResult::kLoss; } LOG(WARNING) << "Chunk " << chunk_index << " in " << tar_path.string() << " has unexpected result_q=" << frame.result_q << '.'; return std::nullopt; } ResultCounts CountResultsInTar(const fs::path& tar_path) { ResultCounts counts; TarChunkSource source(tar_path, ChunkSourceLoaderConfig::V6TrainingData); const size_t chunk_count = source.GetChunkCount(); for (size_t index = 0; index < chunk_count; ++index) { const std::optional chunk = source.GetChunkPrefix(index, sizeof(FrameType)); if (!chunk) { LOG(WARNING) << "Skipping unreadable chunk " << index << " in " << tar_path.string(); continue; } const std::optional result = DetermineChunkResult(*chunk, index, tar_path); if (!result) continue; switch (*result) { case ChunkResult::kWin: ++counts.wins; break; case ChunkResult::kDraw: ++counts.draws; break; case ChunkResult::kLoss: ++counts.losses; break; } } return counts; } } // namespace int main(int argc, char** argv) { absl::InitializeLog(); std::vector positional = absl::ParseCommandLine(argc, argv); absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo); if (positional.size() <= 1) { LOG(FATAL) << "Provide at least one .tar file as a positional argument."; } const fs::path output_path(absl::GetFlag(FLAGS_output_csv)); std::ofstream file_stream; std::ostream& output = SelectOutput(output_path, file_stream); absl::Mutex output_mutex; const CsvWriter writer(&output, &output_mutex); std::vector tar_files; tar_files.reserve(positional.size() - 1); for (size_t i = 1; i < positional.size(); ++i) { tar_files.emplace_back(positional[i]); } const int num_threads_flag = absl::GetFlag(FLAGS_num_threads); size_t worker_count = 0; if (num_threads_flag > 0) { worker_count = static_cast(num_threads_flag); } else { const unsigned int hw_threads = std::thread::hardware_concurrency(); worker_count = hw_threads > 0 ? 
static_cast(hw_threads) : 1; } worker_count = std::max(1, std::min(worker_count, tar_files.size())); std::atomic next_index(0); std::vector workers; workers.reserve(worker_count); for (size_t worker_id = 0; worker_id < worker_count; ++worker_id) { workers.emplace_back([&tar_files, &next_index, &writer]() { while (true) { const size_t index = next_index.fetch_add(1, std::memory_order_relaxed); if (index >= tar_files.size()) break; const fs::path& tar_path = tar_files[index]; LOG(INFO) << "Processing tar file: " << tar_path.string(); try { const ResultCounts counts = CountResultsInTar(tar_path); writer.Write(tar_path.filename().string(), counts); } catch (const std::exception& exception) { LOG(WARNING) << "Failed to process tar file " << tar_path.string() << ": " << exception.what(); } } }); } for (auto& worker : workers) { worker.join(); } return 0; } ================================================ FILE: csrc/tools/startpos_policy_distribution_main.cc ================================================ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "loader/chunk_source/chunk_source.h" #include "loader/chunk_source/tar_chunk_source.h" #include "trainingdata/trainingdata_v6.h" ABSL_FLAG(std::string, input_dir, ".", "Directory to scan for .tar files."); ABSL_FLAG(std::string, output_csv, "", "Destination CSV file. 
Writes to stdout if empty."); namespace { namespace fs = std::filesystem; using ::lczero::training::ChunkSource; using ::lczero::training::ChunkSourceLoaderConfig; using ::lczero::training::FrameType; using ::lczero::training::TarChunkSource; constexpr std::array kStartPositionPlanes = { 0x000000000000ff00ull, 0x0000000000000042ull, 0x0000000000000024ull, 0x0000000000000081ull, 0x0000000000000010ull, 0x0000000000000008ull, 0x00ff000000000000ull, 0x4200000000000000ull, 0x2400000000000000ull, 0x8100000000000000ull, 0x1000000000000000ull, 0x0800000000000000ull, 0x0000000000000000ull, 0x0000000000000000ull, 0x0000000000000000ull, 0x0000000000000000ull}; using PolicyProbe = std::pair; constexpr std::array kPolicyProbes = { {{378, "g2g4"}, {346, "f2f3"}, {34, "b1a3"}, {161, "g1h3"}, {403, "h2h4"}, {351, "f2f4"}, {234, "b2b4"}, {207, "a2a4"}, {288, "d2d3"}, {204, "a2a3"}, {259, "c2c3"}, {36, "b1c3"}, {400, "h2h3"}, {230, "b2b3"}, {322, "e2e4"}, {317, "e2e3"}, {374, "g2g3"}, {264, "c2c4"}, {159, "g1f3"}, {293, "d2d4"}}}; bool MatchesStartPosition(const FrameType& data) { return absl::c_equal( kStartPositionPlanes, absl::Span(data.planes, kStartPositionPlanes.size())); } std::vector CollectTarFiles(const fs::path& directory) { std::vector files; for (const auto& entry : fs::directory_iterator(directory)) { const fs::path& path = entry.path(); if (entry.is_regular_file() && path.extension() == ".tar") { files.push_back(path); } } absl::c_sort(files, [](const fs::path& lhs, const fs::path& rhs) { return lhs.filename() < rhs.filename(); }); return files; } void WriteHeader(std::ostream& output) { output << "file,index"; for (const auto& probe : kPolicyProbes) output << ',' << probe.second; output << '\n'; } void WriteRow(std::ostream& output, absl::string_view sort_key, size_t index, const FrameType& data) { output << sort_key << ',' << index; for (const auto& probe : kPolicyProbes) { output << ',' << data.probabilities[probe.first]; } output << '\n'; } void 
ProcessTarFile(const fs::path& tar_path, std::ostream& output) { std::unique_ptr source = std::make_unique( tar_path, ChunkSourceLoaderConfig::V6TrainingData); const std::string sort_key = source->GetChunkSortKey(); for (size_t i = 0, count = source->GetChunkCount(); i < count; ++i) { const std::optional> chunk = source->GetChunkData(i); if (!chunk || chunk->empty()) continue; const FrameType& entry = chunk->front(); if (!MatchesStartPosition(entry)) continue; WriteRow(output, sort_key, i, entry); } } std::ostream& SelectOutput(const fs::path& output_path, std::ofstream& file_stream) { if (output_path.empty()) return std::cout; file_stream.open(output_path, std::ios::out | std::ios::trunc); if (!file_stream) { LOG(FATAL) << "Failed to open output file: " << output_path.string(); } return file_stream; } } // namespace int main(int argc, char** argv) { absl::InitializeLog(); absl::ParseCommandLine(argc, argv); const fs::path input_dir(absl::GetFlag(FLAGS_input_dir)); const fs::path output_path(absl::GetFlag(FLAGS_output_csv)); if (!fs::is_directory(input_dir)) { LOG(FATAL) << "Input directory does not exist: " << input_dir.string(); } std::ofstream file_stream; std::ostream& output = SelectOutput(output_path, file_stream); WriteHeader(output); for (const auto& tar_path : CollectTarFiles(input_dir)) { LOG(INFO) << "Processing tar file: " << tar_path.string(); try { ProcessTarFile(tar_path, output); } catch (const std::exception& e) { LOG(WARNING) << "Failed to process tar file " << tar_path << ": " << e.what(); } } return 0; } ================================================ FILE: csrc/utils/gz.cc ================================================ #include "utils/gz.h" #include #include #include #include namespace lczero { namespace training { std::string GunzipBuffer(std::string_view buffer) { z_stream strm = {}; int ret = inflateInit2(&strm, 16 + MAX_WBITS); if (ret != Z_OK) { throw GunzipError("Failed to initialize zlib inflate"); } strm.avail_in = buffer.size(); 
strm.next_in = reinterpret_cast(const_cast(buffer.data())); constexpr size_t kChunkSize = 16384; std::string output; std::array temp_buffer; do { strm.avail_out = kChunkSize; strm.next_out = reinterpret_cast(temp_buffer.data()); ret = inflate(&strm, Z_NO_FLUSH); if (ret == Z_STREAM_ERROR || ret == Z_NEED_DICT || ret == Z_DATA_ERROR || ret == Z_MEM_ERROR) { inflateEnd(&strm); throw GunzipError("zlib inflate error"); } size_t bytes_written = kChunkSize - strm.avail_out; output.append(temp_buffer.begin(), temp_buffer.begin() + bytes_written); } while (strm.avail_out == 0); inflateEnd(&strm); if (ret != Z_STREAM_END) { throw GunzipError("Incomplete gzip decompression"); } return output; } } // namespace training } // namespace lczero ================================================ FILE: csrc/utils/gz.h ================================================ #pragma once #include #include #include namespace lczero { namespace training { class GunzipError : public std::runtime_error { public: using std::runtime_error::runtime_error; }; std::string GunzipBuffer(std::string_view buffer); } // namespace training } // namespace lczero ================================================ FILE: csrc/utils/metrics/exponential_aggregator.h ================================================ #pragma once #include #include #include #include #include #include #include #include #include #include namespace lczero { // 1 second period is exact. Other periods are powers of two. 
enum class TimePeriod { kEmpty = -128, k1Millisecond = -10, k2Milliseconds, k4Milliseconds, k8Milliseconds, k16Milliseconds, k31Milliseconds, k63Milliseconds, k125Milliseconds, k250Milliseconds, k500Milliseconds, k1Second /* = 0 */, k2Seconds, k4Seconds, k8Seconds, k16Seconds, k32Seconds, k1Minute, k2Minutes, k4Minutes, k9Minutes, k17Minutes, k36Minutes, k1Hour, k2Hours, k5Hours, k9Hours, k18Hours, k36Hours, k3Days, k6Days, k12Days, k24Days, k49Days, k97Days, k194Days, k388Days, k777Days, k2Years, k4Years, k9Years, k17Years, /* = 29 */ kAllTime = 127 }; // ExponentialAggregator metrics over exponentially increasing time periods. // // The template parameter `Metric` can be used in two ways: // 1. Function-based: Pass Clear and UpdateFrom functions to constructor. // This is the new approach for protobuf-based metrics. // 2. Method-based: The metric type has Reset() and MergeFrom() methods. // This maintains backward compatibility with existing C++ struct metrics. // // For method-based metrics, the type must satisfy: // * It must have a `Reset()` method that clears its state. // * It must have a `MergeFrom(const Metric& other)` method to merge another // metric into itself (used for bucket-to-bucket merging and live ingestion). // * It must behave like a monoid (actually, unital magma is sufficient): // * Merging with a default-constructed (empty) metric is a no-op. // * The `MergeFrom` operation must be associative (actually, not really; // currently we always merge old to new). It does not need to be // commutative. // * Having a `std::swap(Metric&, Metric&)` method is also beneficial. template class ExponentialAggregator { public: using Duration = std::chrono::nanoseconds; using Clock = std::chrono::steady_clock; using ClearFn = std::function; using UpdateFromFn = std::function; static constexpr TimePeriod kResolution = Resolution; // Resets the aggregator, clearing all buckets and pending metrics. 
void Reset(Clock::time_point now = Clock::now()); // Ingests the passed metric into the pending bucket using MergeFrom + Reset. void RecordMetrics(Metric&& metric); // Returns the latest completed metrics bucket for the given time period and // duration since that period finished last time. If now is nullopt, it // excludes the time since the last metrics flush. std::pair GetBucketMetrics( TimePeriod period, std::optional now = std::nullopt) const; // Returns the current metrics that have been collected for at least the // specified duration. Returns the metrics and the duration since the // beginning of the covered period. If `now` is nullopt, it excludes metrics // and the time since the last metrics flush. std::pair GetAggregateEndingNow( Duration duration, std::optional now = Clock::now()) const; // Flushes the current pending bucket into the exponential metrics and // advances time by the elapsed duration, potentially processing multiple // ticks. Returns the largest time period that was updated by this advance // (all smaller periods are also updated). // // Note: Advance() should typically be called more frequently than the tick // frequency. When called less frequently (advancing multiple ticks at once), // live statistics go into the first bucket and subsequent buckets are padded // with empty metrics. TimePeriod Advance(Clock::time_point now = Clock::now()); constexpr Duration GetResolution() const { return kPeriodDuration; } // Primary constructor for new protobuf-based metrics. // Takes two free functions that define the metric's behavior. ExponentialAggregator(ClearFn clear_fn, UpdateFromFn update_from_fn); // Default constructor for backward compatibility with old C++ metrics. // This will only compile if `Metric` has Reset() and MergeFrom() methods. 
ExponentialAggregator(); private: // Constexpr power of 2 using bit shifts static constexpr double constexpr_pow2(int exp) { if (exp >= 0) { return static_cast(1LL << exp); } else { return 1.0 / static_cast(1LL << (-exp)); } } static constexpr Duration kPeriodDuration = std::chrono::duration_cast(std::chrono::duration( constexpr_pow2(static_cast(Resolution)))); static size_t GetBucketIndex(TimePeriod period) { return static_cast(period) - static_cast(Resolution); } // The aggregation strategy is analogous to a binary counter. `tick_count_` // represents the counter's value, and the `buckets_` array corresponds to its // bits, each covering an exponentially larger time period. // // Advancing time increments `tick_count_`. When a bit flips from 1 to 0, // its bucket's metric is merged (the "carry") into the next higher bucket. // // Note that buckets are never empty. A bucket whose corresponding bit in // `tick_count_` is '0' simply holds the last complete metric for its time // period. This ensures that a valid, historical metric is always available // for the bucket query. ClearFn clear_fn_; UpdateFromFn update_from_fn_; mutable absl::Mutex mutex_; size_t tick_count_ ABSL_GUARDED_BY(mutex_); // Buckets for each time period, starting from Resolution. 
std::vector buckets_ ABSL_GUARDED_BY(mutex_); Clock::time_point last_tick_time_ ABSL_GUARDED_BY(mutex_); mutable absl::Mutex pending_bucket_mutex_ ABSL_ACQUIRED_AFTER(mutex_); Metric pending_bucket_ ABSL_GUARDED_BY(pending_bucket_mutex_); }; template ExponentialAggregator::ExponentialAggregator( ClearFn clear_fn, UpdateFromFn update_from_fn) : clear_fn_(std::move(clear_fn)), update_from_fn_(std::move(update_from_fn)), tick_count_(0), last_tick_time_(Clock::now()) {} template ExponentialAggregator::ExponentialAggregator() : tick_count_(0), last_tick_time_(Clock::now()) { // For backward compatibility, use metric methods if available if constexpr (requires(Metric& m, const Metric& other) { m.Reset(); m.MergeFrom(other); }) { clear_fn_ = [](Metric& m) { m.Reset(); }; update_from_fn_ = [](Metric& dest, const Metric& src) { dest.MergeFrom(src); }; } else { static_assert(sizeof(Metric) == 0, "Metric type must have Reset() and MergeFrom() methods or " "use function-based constructor"); } } template void ExponentialAggregator::Reset( std::chrono::steady_clock::time_point now) { absl::MutexLock lock(&mutex_); tick_count_ = 0; buckets_.clear(); last_tick_time_ = now; absl::MutexLock pending_bucket_lock(&pending_bucket_mutex_); clear_fn_(pending_bucket_); } template void ExponentialAggregator::RecordMetrics(Metric&& metric) { absl::MutexLock lock(&pending_bucket_mutex_); update_from_fn_(pending_bucket_, metric); clear_fn_(metric); } template auto ExponentialAggregator::GetBucketMetrics( TimePeriod period, std::optional now) const -> std::pair { absl::MutexLock lock(&mutex_); const size_t index = GetBucketIndex(period); const Duration duration_since_update = kPeriodDuration * (tick_count_ % (1ULL << index)) + (now.has_value() ? 
Duration(*now - last_tick_time_) : Duration::zero()); if (index >= buckets_.size()) return {Metric(), duration_since_update}; return {buckets_[index], duration_since_update}; } template auto ExponentialAggregator::GetAggregateEndingNow( Duration duration, std::optional now) const -> std::pair { Duration result_duration = Duration::zero(); Metric result; { absl::MutexLock lock(&mutex_); if (now.has_value()) { // If we'll use pending bucket, remove its duration from the request. // The actual bucket we'll merge in the end as we have to merge newer // into older buckets. Duration duration_since_update = *now - last_tick_time_; duration -= duration_since_update; result_duration += duration_since_update; } if (duration > Duration::zero()) { // Convert the input `duration` into `num_ticks` (the number of base // time periods), rounding up. const auto div = std::div(duration.count(), kPeriodDuration.count()); const size_t num_ticks = div.quot + bool(div.rem); // To cover the remaining `duration`, we select the minimal set of active // historical buckets (where the corresponding bit in `tick_count_` is 1) // that, when combined, meet or exceed the target duration. // // 1. First we determine the bit width of `num_ticks` (e.g., for 13 // (1101b), the width is 4). Create a candidate set of ticks by // masking `tick_count_` to this width. If this masked value is >= // `num_ticks`, the corresponding set of buckets is sufficient. // 2. If the masked value from step 2 is insufficient, we include one // additional bucket. It doesn't matter if it's active or not. size_t masked_ticks = [&]() { const auto bit_width = std::bit_width(num_ticks); const size_t mask = ~((~size_t{0}) << bit_width); const size_t candidate_ticks = tick_count_ & mask; if (candidate_ticks >= num_ticks) return candidate_ticks; // One additional tick is needed. return candidate_ticks | (size_t{1} << bit_width); }(); while (masked_ticks) { // Start merging from the highest bit (older bucket) to the lowest. 
size_t idx = std::bit_width(masked_ticks) - 1; masked_ticks &= ~(1ULL << idx); if (idx < buckets_.size()) { update_from_fn_(result, buckets_[idx]); result_duration += kPeriodDuration * (1ULL << idx); } } } } if (now.has_value()) { // The pending bucket is merged last, as we have to merge newer into older // buckets. absl::MutexLock lock(&pending_bucket_mutex_); update_from_fn_(result, pending_bucket_); } return {result, result_duration}; } template auto ExponentialAggregator::Advance(Clock::time_point now) -> TimePeriod { absl::MutexLock lock(&mutex_); const int num_ticks_to_advance = (now - last_tick_time_) / kPeriodDuration; if (num_ticks_to_advance <= 0) return TimePeriod::kEmpty; last_tick_time_ += num_ticks_to_advance * kPeriodDuration; Metric live_carry; { // What was pending, now becomes carry. Pending bucket is cleared. absl::MutexLock pending_bucket_lock(&pending_bucket_mutex_); live_carry = std::move(pending_bucket_); clear_fn_(pending_bucket_); } const size_t initial_tick_count = tick_count_; auto one_tick = [&](Metric& carry) { ++tick_count_; for (size_t i = 0;; ++i) { const uint64_t interval_size = 1ULL << i; if ((tick_count_ % interval_size) != 0) break; while (i >= buckets_.size()) buckets_.emplace_back(); // We always merge new into the old, so we swap the carry first, and then // merge into it. std::swap(carry, buckets_[i]); update_from_fn_(carry, buckets_[i]); } }; // Carry the pending bucket into the first tick. one_tick(live_carry); // Then, if more than one tick is requested, we carry the empty bucket // through the remaining ticks. for (int i = 1; i < num_ticks_to_advance; ++i) { Metric empty_carry; one_tick(empty_carry); } // Largest time period is the highest bit that was flipped in the process. // To find it, we XOR the initial tick count with the current one, and // find the highest bit. 
return static_cast( std::bit_width(initial_tick_count ^ tick_count_) - 1 + static_cast(Resolution)); } } // namespace lczero ================================================ FILE: csrc/utils/metrics/group.h ================================================ #pragma once #include #include "utils/metrics/printer.h" namespace lczero { // Metric is a struct that implements the following interface: // - void Reset(); // Resets the metric to its initial state. // - void MergeFrom(const Metric& other); // Merges another metric into this // one. Note that the incoming always happens later in time, so if e.g. merge // keeps the latest value, it should update the current value with the incoming // one. Used for bucket-to-bucket merging and live data ingestion. // - (optional) std::string_view name() const; // - (optional) std::string ToString() const; // If provided, returns a string // representation of the metric. // Group several metric types together. This allows us to have a single // `MetricGroup` that contains multiple different metrics. template class MetricGroup { public: MetricGroup() = default; // Calls reset on all stats. void Reset(); // Merges each individual stat from `other` into this group. void MergeFrom(const MetricGroup& other); // Merges a single stat from `other` into this group. template void MergeFrom(const T& other); // Gets a const reference to a specific stat record. template const T& Get() const; // Gets a mutable pointer to a specific stat record. template T* GetMutable(); // Calls MetricPrinter for each stat in the group. 
void Print(MetricPrinter& printer) const; private: std::tuple stats_; }; template void MetricGroup::Reset() { (std::get(stats_).Reset(), ...); } template void MetricGroup::MergeFrom( const MetricGroup& other) { (std::get(stats_).MergeFrom(std::get(other.stats_)), ...); } template template void MetricGroup::MergeFrom(const T& other) { static_assert((std::is_same_v || ...), "Type T must be one of the Stats types"); std::get(stats_).MergeFrom(other); } template template const T& MetricGroup::Get() const { static_assert((std::is_same_v || ...), "Type T must be one of the Stats types"); return std::get(stats_); } template template T* MetricGroup::GetMutable() { static_assert((std::is_same_v || ...), "Type T must be one of the Stats types"); return &std::get(stats_); } template void MetricGroup::Print(MetricPrinter& printer) const { ( [&](const auto& stat) { if constexpr (requires { stat.Print(printer); }) { stat.Print(printer); } }(std::get(stats_)), ...); } } // namespace lczero ================================================ FILE: csrc/utils/metrics/load_metric.h ================================================ #pragma once #include #include #include #include #include "proto/training_metrics.pb.h" namespace lczero { // ABOUTME: Helper class to manage timing logic for LoadMetricProto. // ABOUTME: Tracks active periods and flushes accumulated time to the protobuf // metric. class LoadMetricUpdater { public: using Clock = std::chrono::steady_clock; using Duration = std::chrono::duration; explicit LoadMetricUpdater(Clock::time_point initial_time = Clock::now()) : last_flush_time_(initial_time), is_load_active_(true) {} // Starts tracking load from the given time point. // Returns true if load was previously stopped (successful start). bool LoadStart(Clock::time_point now = Clock::now()) { absl::MutexLock lock(&mutex_); FlushInternal(now); bool was_stopped = !is_load_active_; is_load_active_ = true; return was_stopped; } // Stops tracking load at the given time point. 
// Returns true if load was previously active (successful stop). bool LoadStop(Clock::time_point now = Clock::now()) { absl::MutexLock lock(&mutex_); FlushInternal(now); bool was_active = is_load_active_; is_load_active_ = false; return was_active; } // Flushes any uncounted load time into the metric. void Flush(Clock::time_point now = Clock::now()) { absl::MutexLock lock(&mutex_); FlushInternal(now); } // Flushes metrics and returns a copy, resetting the internal metric. LoadMetricProto FlushMetrics(Clock::time_point now = Clock::now()) { absl::MutexLock lock(&mutex_); FlushInternal(now); LoadMetricProto result = metric_; metric_.Clear(); return result; } private: // Flushes any uncounted load time into the metric (assumes mutex held). void FlushInternal(Clock::time_point now) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { Duration elapsed = now - last_flush_time_; double elapsed_seconds = elapsed.count(); metric_.set_total_seconds(metric_.total_seconds() + elapsed_seconds); if (is_load_active_) { metric_.set_load_seconds(metric_.load_seconds() + elapsed_seconds); } last_flush_time_ = now; } mutable absl::Mutex mutex_; LoadMetricProto metric_ ABSL_GUARDED_BY(mutex_); Clock::time_point last_flush_time_ ABSL_GUARDED_BY(mutex_); bool is_load_active_ ABSL_GUARDED_BY(mutex_); }; // UpdateFrom function for LoadMetricProto - simple additive behavior inline void UpdateFrom(LoadMetricProto& dest, const LoadMetricProto& src) { if (src.has_name()) dest.set_name(src.name()); dest.set_load_seconds(dest.load_seconds() + src.load_seconds()); dest.set_total_seconds(dest.total_seconds() + src.total_seconds()); } // RAII class to temporarily pause load tracking. class LoadMetricPauser { public: explicit LoadMetricPauser(LoadMetricUpdater& updater) : updater_(updater) { successfully_paused_ = updater_.LoadStop(); } ~LoadMetricPauser() { if (successfully_paused_ && should_resume_) { updater_.LoadStart(); } } // Prevents the pauser from resuming load tracking in destructor. 
void DoNotResume() { should_resume_ = false; } LoadMetricPauser(const LoadMetricPauser&) = delete; LoadMetricPauser& operator=(const LoadMetricPauser&) = delete; LoadMetricPauser(LoadMetricPauser&&) = delete; LoadMetricPauser& operator=(LoadMetricPauser&&) = delete; private: LoadMetricUpdater& updater_; bool successfully_paused_; bool should_resume_ = true; }; } // namespace lczero ================================================ FILE: csrc/utils/metrics/load_metric_test.cc ================================================ #include "utils/metrics/load_metric.h" #include #include #include #include "proto/training_metrics.pb.h" #include "utils/metrics/exponential_aggregator.h" namespace lczero { namespace training { class LoadMetricTest : public ::testing::Test { protected: using Clock = LoadMetricUpdater::Clock; void SetUp() override { start_time_ = Clock::now(); } Clock::time_point start_time_; }; TEST_F(LoadMetricTest, BasicLoadMetricProto) { LoadMetricProto metric; EXPECT_EQ(metric.load_seconds(), 0.0); EXPECT_EQ(metric.total_seconds(), 0.0); // Test UpdateFrom with LoadMetricUpdater LoadMetricUpdater other_updater(start_time_); other_updater.LoadStart(start_time_); other_updater.LoadStop(start_time_ + std::chrono::milliseconds(500)); LoadMetricProto other = other_updater.FlushMetrics(start_time_ + std::chrono::milliseconds(500)); UpdateFrom(metric, other); EXPECT_NEAR(metric.load_seconds(), 0.5, 1e-6); EXPECT_NEAR(metric.total_seconds(), 0.5, 1e-6); // Test Clear (used to be Reset) metric.Clear(); EXPECT_EQ(metric.load_seconds(), 0.0); EXPECT_EQ(metric.total_seconds(), 0.0); } TEST_F(LoadMetricTest, LoadMetricUpdaterBasic) { LoadMetricUpdater updater(start_time_); auto now = start_time_; // Start load and verify initial state updater.LoadStart(now); LoadMetricProto metric = updater.FlushMetrics(now); EXPECT_EQ(metric.load_seconds(), 0.0); EXPECT_EQ(metric.total_seconds(), 0.0); // Advance time and stop load now += std::chrono::milliseconds(100); 
updater.LoadStop(now); metric = updater.FlushMetrics(now); EXPECT_NEAR(metric.load_seconds(), 0.1, 1e-6); EXPECT_NEAR(metric.total_seconds(), 0.1, 1e-6); // Wait idle time, then start again now += std::chrono::milliseconds(50); updater.LoadStart(now); metric = updater.FlushMetrics(now); EXPECT_NEAR(metric.load_seconds(), 0.0, 1e-6); // Reset after flush EXPECT_NEAR(metric.total_seconds(), 0.05, 1e-6); // Only idle time now += std::chrono::milliseconds(200); updater.LoadStop(now); metric = updater.FlushMetrics(now); EXPECT_NEAR(metric.load_seconds(), 0.2, 1e-6); // 0.2 load EXPECT_NEAR(metric.total_seconds(), 0.2, 1e-6); // 0.2 total } TEST_F(LoadMetricTest, LoadMetricUpdaterFlush) { LoadMetricUpdater updater(start_time_); auto now = start_time_; // Start load updater.LoadStart(now); now += std::chrono::milliseconds(100); // Flush should update the internal metric updater.Flush(now); LoadMetricProto metric = updater.FlushMetrics(now); EXPECT_NEAR(metric.load_seconds(), 0.1, 1e-6); EXPECT_NEAR(metric.total_seconds(), 0.1, 1e-6); // Continue loading now += std::chrono::milliseconds(50); updater.LoadStop(now); metric = updater.FlushMetrics(now); EXPECT_NEAR(metric.load_seconds(), 0.05, 1e-6); // Only new load time EXPECT_NEAR(metric.total_seconds(), 0.05, 1e-6); // Only new total time } TEST_F(LoadMetricTest, LoadMetricProtoMerging) { LoadMetricUpdater updater1(start_time_); LoadMetricUpdater updater2(start_time_); auto now = start_time_; // Create load in updater1 updater1.LoadStart(now); updater1.LoadStop(now + std::chrono::milliseconds(100)); LoadMetricProto metric1 = updater1.FlushMetrics(now + std::chrono::milliseconds(100)); EXPECT_NEAR(metric1.load_seconds(), 0.1, 1e-6); EXPECT_NEAR(metric1.total_seconds(), 0.1, 1e-6); // Create load in updater2 updater2.LoadStart(now); updater2.LoadStop(now + std::chrono::milliseconds(100)); LoadMetricProto metric2 = updater2.FlushMetrics(now + std::chrono::milliseconds(100)); EXPECT_NEAR(metric2.load_seconds(), 0.1, 1e-6); 
EXPECT_NEAR(metric2.total_seconds(), 0.1, 1e-6); // Merge UpdateFrom(metric1, metric2); EXPECT_NEAR(metric1.load_seconds(), 0.2, 1e-6); EXPECT_NEAR(metric1.total_seconds(), 0.2, 1e-6); EXPECT_NEAR(metric2.load_seconds(), 0.1, 1e-6); // Source unchanged EXPECT_NEAR(metric2.total_seconds(), 0.1, 1e-6); // Source unchanged } TEST_F(LoadMetricTest, LoadMetricProtoMoveSemantics) { // Test that LoadMetricProto move semantics work correctly LoadMetricUpdater source_updater(start_time_); source_updater.LoadStart(start_time_); source_updater.LoadStop(start_time_ + std::chrono::milliseconds(100)); LoadMetricProto source = source_updater.FlushMetrics(start_time_ + std::chrono::milliseconds(100)); EXPECT_NEAR(source.load_seconds(), 0.1, 1e-6); EXPECT_NEAR(source.total_seconds(), 0.1, 1e-6); // Test move construction LoadMetricProto moved_constructed(std::move(source)); EXPECT_NEAR(moved_constructed.load_seconds(), 0.1, 1e-6); EXPECT_NEAR(moved_constructed.total_seconds(), 0.1, 1e-6); // Test move assignment LoadMetricProto move_assigned; LoadMetricUpdater another_updater(start_time_); another_updater.LoadStart(start_time_); another_updater.LoadStop(start_time_ + std::chrono::milliseconds(50)); LoadMetricProto another_source = another_updater.FlushMetrics(start_time_ + std::chrono::milliseconds(50)); EXPECT_NEAR(another_source.load_seconds(), 0.05, 1e-6); EXPECT_NEAR(another_source.total_seconds(), 0.05, 1e-6); move_assigned = std::move(another_source); EXPECT_NEAR(move_assigned.load_seconds(), 0.05, 1e-6); EXPECT_NEAR(move_assigned.total_seconds(), 0.05, 1e-6); // Test UpdateFrom LoadMetricProto dest; UpdateFrom(dest, moved_constructed); EXPECT_NEAR(dest.load_seconds(), 0.1, 1e-6); EXPECT_NEAR(dest.total_seconds(), 0.1, 1e-6); UpdateFrom(dest, move_assigned); EXPECT_NEAR(dest.load_seconds(), 0.15, 1e-6); EXPECT_NEAR(dest.total_seconds(), 0.15, 1e-6); } TEST_F(LoadMetricTest, LoadUtilizationTracking) { LoadMetricUpdater updater(start_time_); auto now = start_time_; // Start 
with some load time (load is active by default now) now += std::chrono::milliseconds(100); updater.LoadStop(now); // Stop load to create idle time LoadMetricProto metric = updater.FlushMetrics(now); EXPECT_NEAR(metric.load_seconds(), 0.1, 1e-6); // 100ms load EXPECT_NEAR(metric.total_seconds(), 0.1, 1e-6); // 100ms total // Add some idle time (load is stopped) now += std::chrono::milliseconds(100); updater.LoadStart(now); // Start load again metric = updater.FlushMetrics(now); EXPECT_NEAR(metric.load_seconds(), 0.0, 1e-6); // No load time EXPECT_NEAR(metric.total_seconds(), 0.1, 1e-6); // 100ms idle // Add more load time now += std::chrono::milliseconds(200); updater.LoadStop(now); metric = updater.FlushMetrics(now); EXPECT_NEAR(metric.load_seconds(), 0.2, 1e-6); // 200ms load EXPECT_NEAR(metric.total_seconds(), 0.2, 1e-6); // 200ms total // Test complete utilization tracking with one updater LoadMetricUpdater total_updater(start_time_); auto total_now = start_time_; // 100ms load (active by default) total_now += std::chrono::milliseconds(100); total_updater.LoadStop(total_now); // 100ms idle total_now += std::chrono::milliseconds(100); total_updater.LoadStart(total_now); // 200ms load total_now += std::chrono::milliseconds(200); LoadMetricProto total_metric = total_updater.FlushMetrics(total_now); // Calculate utilization double utilization = total_metric.load_seconds() / total_metric.total_seconds(); EXPECT_NEAR(utilization, 0.75, 1e-6); // 75% utilization (300ms load / 400ms total) } class LoadMetricProtoIntegrationTest : public ::testing::Test { protected: using TestAggregator = ExponentialAggregator; using Clock = TestAggregator::Clock; void SetUp() override { aggregator_ = std::make_unique( [](LoadMetricProto& m) { m.Clear(); }, [](LoadMetricProto& dest, const LoadMetricProto& src) { UpdateFrom(dest, src); }); start_time_ = Clock::now(); } std::unique_ptr aggregator_; Clock::time_point start_time_; }; TEST_F(LoadMetricProtoIntegrationTest, 
RecordMetricsWithUpdater) { auto current_time = start_time_; // Create metric with updater, simulate some load LoadMetricUpdater updater(current_time); updater.LoadStart(current_time); current_time += std::chrono::milliseconds(150); // Flush and get metric LoadMetricProto metric = updater.FlushMetrics(current_time); // Record the metric (this should use UpdateFrom + Reset) aggregator_->RecordMetrics(std::move(metric)); // Get live metrics auto [live_metrics, age] = aggregator_->GetAggregateEndingNow( TestAggregator::Duration::zero(), current_time); EXPECT_NEAR(live_metrics.load_seconds(), 0.15, 1e-6); } TEST_F(LoadMetricProtoIntegrationTest, MultipleRecordMetrics) { auto current_time = start_time_; // First metric LoadMetricUpdater updater1(current_time); updater1.LoadStart(current_time); current_time += std::chrono::milliseconds(100); LoadMetricProto metric1 = updater1.FlushMetrics(current_time); aggregator_->RecordMetrics(std::move(metric1)); // Second metric LoadMetricUpdater updater2(current_time); updater2.LoadStart(current_time); current_time += std::chrono::milliseconds(75); LoadMetricProto metric2 = updater2.FlushMetrics(current_time); aggregator_->RecordMetrics(std::move(metric2)); // Get live metrics auto [live_metrics, age] = aggregator_->GetAggregateEndingNow( TestAggregator::Duration::zero(), current_time); EXPECT_NEAR(live_metrics.load_seconds(), 0.175, 1e-6); // 0.1 + 0.075 } TEST_F(LoadMetricProtoIntegrationTest, AdvanceTest) { auto current_time = start_time_; // Add some metrics LoadMetricUpdater updater(current_time); updater.LoadStart(current_time); updater.LoadStop(current_time + std::chrono::milliseconds(100)); LoadMetricProto metric = updater.FlushMetrics(current_time + std::chrono::milliseconds(100)); aggregator_->RecordMetrics(std::move(metric)); // Advance to move live metrics to buckets auto tick_time = start_time_ + aggregator_->GetResolution(); auto period = aggregator_->Advance(tick_time); // Should return the base time period 
EXPECT_EQ(period, TimePeriod::k16Milliseconds); // Live metrics should be empty after advance auto [live_metrics, age] = aggregator_->GetAggregateEndingNow( TestAggregator::Duration::zero(), tick_time); EXPECT_EQ(live_metrics.load_seconds(), 0.0); } TEST_F(LoadMetricTest, LoadStartStopReturnValues) { LoadMetricUpdater updater(start_time_); // Load starts active, so LoadStart should return false (already active) EXPECT_FALSE(updater.LoadStart()); // LoadStop should return true (was active) EXPECT_TRUE(updater.LoadStop()); // LoadStop again should return false (already stopped) EXPECT_FALSE(updater.LoadStop()); // LoadStart should return true (was stopped) EXPECT_TRUE(updater.LoadStart()); // LoadStart again should return false (already active) EXPECT_FALSE(updater.LoadStart()); } } // namespace training } // namespace lczero int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } ================================================ FILE: csrc/utils/metrics/printer.h ================================================ #pragma once #include #include #include namespace lczero { class MetricPrinter { public: virtual ~MetricPrinter() = default; virtual void StartGroup(std::string_view group_name) = 0; virtual void Print(std::string_view metric_name, const absl::AlphaNum& value) = 0; virtual void EndGroup() = 0; }; class StringMetricPrinter : public MetricPrinter { public: StringMetricPrinter(std::string* output) : output_(output) {} void StartGroup(std::string_view group_name) override { if (!first_group_) absl::StrAppend(output_, ", "); absl::StrAppend(output_, group_name, "={"); first_group_ = false; first_metric_ = true; } void Print(std::string_view metric_name, const absl::AlphaNum& value) override { if (!first_metric_) absl::StrAppend(output_, ", "); absl::StrAppend(output_, metric_name, "=", value); first_metric_ = false; first_group_ = false; } void EndGroup() override { absl::StrAppend(output_, "}"); } private: std::string* 
output_; bool first_metric_ = true; bool first_group_ = true; }; template std::string MetricToString(const T& metric) { std::string result; StringMetricPrinter printer(&result); metric.Print(printer); return result; } } // namespace lczero ================================================ FILE: csrc/utils/metrics/statistics_metric.h ================================================ #pragma once #include #include "proto/training_metrics.pb.h" namespace lczero { // Helper function to add a sample to StatisticsProtoInt64 inline void AddSample(StatisticsProtoInt64& stats, int64_t value) { stats.set_min(std::min(stats.min(), value)); stats.set_max(std::max(stats.max(), value)); stats.set_sum(stats.sum() + value); stats.set_count(stats.count() + 1); stats.set_latest(value); } // Helper function to add a sample to StatisticsProtoDouble inline void AddSample(StatisticsProtoDouble& stats, double value) { stats.set_min(std::min(stats.min(), value)); stats.set_max(std::max(stats.max(), value)); stats.set_sum(stats.sum() + value); stats.set_count(stats.count() + 1); stats.set_latest(value); } // UpdateFrom function for StatisticsProtoInt64 - merges statistics inline void UpdateFrom(StatisticsProtoInt64& dest, const StatisticsProtoInt64& src) { if (src.count() == 0) return; // Nothing to merge from empty source dest.set_min(std::min(dest.min(), src.min())); dest.set_max(std::max(dest.max(), src.max())); dest.set_sum(dest.sum() + src.sum()); dest.set_count(dest.count() + src.count()); dest.set_latest(src.latest()); // Source is newer, use its latest value } // UpdateFrom function for StatisticsProtoDouble - merges statistics inline void UpdateFrom(StatisticsProtoDouble& dest, const StatisticsProtoDouble& src) { if (src.count() == 0) return; // Nothing to merge from empty source if (src.has_name()) dest.set_name(src.name()); dest.set_min(std::min(dest.min(), src.min())); dest.set_max(std::max(dest.max(), src.max())); dest.set_sum(dest.sum() + src.sum()); dest.set_count(dest.count() 
+ src.count()); dest.set_latest(src.latest()); // Source is newer, use its latest value } } // namespace lczero ================================================ FILE: csrc/utils/metrics/stats_test.cc ================================================ #include #include #include #include #include #include "utils/metrics/exponential_aggregator.h" #include "utils/metrics/group.h" #include "utils/metrics/printer.h" namespace lczero { class CounterMetric { public: CounterMetric() : count_(0) {} CounterMetric(int count) : count_(count) {} void Reset() { count_ = 0; } void MergeFrom(const CounterMetric& other) { count_ += other.count_; } void Print(MetricPrinter& printer) const { printer.StartGroup("CounterMetric"); printer.Print("count", static_cast(count_)); printer.EndGroup(); } int count() const { return count_; } void set_count(int count) { count_ = count; } private: int count_; }; class AverageMetric { public: AverageMetric() : sum_(0), count_(0) {} AverageMetric(double sum, int count) : sum_(sum), count_(count) {} void Reset() { sum_ = 0; count_ = 0; } void MergeFrom(const AverageMetric& other) { sum_ += other.sum_; count_ += other.count_; } void Print(MetricPrinter& printer) const { printer.StartGroup("AverageMetric"); printer.Print("sum", sum_); printer.Print("count", static_cast(count_)); if (count_ > 0) { printer.Print("average", sum_ / count_); } printer.EndGroup(); } double average() const { return count_ > 0 ? 
sum_ / count_ : 0.0; } void add_sample(double value) { sum_ += value; count_++; } double sum() const { return sum_; } int count() const { return count_; } private: double sum_; int count_; }; class MaxMetric { public: MaxMetric() : max_value_(0), has_value_(false) {} MaxMetric(double max_value) : max_value_(max_value), has_value_(true) {} void Reset() { max_value_ = 0; has_value_ = false; } void MergeFrom(const MaxMetric& other) { if (other.has_value_) { if (!has_value_ || other.max_value_ > max_value_) { max_value_ = other.max_value_; has_value_ = true; } } } void Print(MetricPrinter& printer) const { printer.StartGroup("MaxMetric"); if (has_value_) { printer.Print("max_value", max_value_); printer.Print("has_value", static_cast(1)); } else { printer.Print("has_value", static_cast(0)); } printer.EndGroup(); } double max_value() const { return max_value_; } bool has_value() const { return has_value_; } void set_value(double value) { if (!has_value_ || value > max_value_) { max_value_ = value; has_value_ = true; } } private: double max_value_; bool has_value_; }; // Optional value metric that demonstrates overshadowing behavior class OptionalValueMetric { public: OptionalValueMetric() : value_(std::nullopt) {} OptionalValueMetric(int value) : value_(value) {} void Reset() { value_ = std::nullopt; } void MergeFrom(const OptionalValueMetric& other) { // Only copy the value if the other metric has one (overshadowing behavior) if (other.value_.has_value()) { value_ = other.value_; } } void Print(MetricPrinter& printer) const { printer.StartGroup("OptionalValueMetric"); if (value_.has_value()) { printer.Print("value", value_.value()); printer.Print("has_value", static_cast(1)); } else { printer.Print("has_value", static_cast(0)); } printer.EndGroup(); } std::optional value() const { return value_; } bool has_value() const { return value_.has_value(); } void set_value(int value) { value_ = value; } private: std::optional value_; }; // Test MetricGroup functionality class 
MetricGroupTest : public ::testing::Test { protected: using TestGroup = MetricGroup; TestGroup group_; }; TEST_F(MetricGroupTest, InitialState) { // Test that metrics are initialized in their default state EXPECT_EQ(group_.Get().count(), 0); EXPECT_EQ(group_.Get().count(), 0); EXPECT_FALSE(group_.Get().has_value()); } TEST_F(MetricGroupTest, GetMutable) { // Test getting mutable references and modifying them auto* counter = group_.GetMutable(); counter->set_count(42); EXPECT_EQ(group_.Get().count(), 42); auto* average = group_.GetMutable(); average->add_sample(10.0); average->add_sample(20.0); EXPECT_EQ(group_.Get().average(), 15.0); auto* max_metric = group_.GetMutable(); max_metric->set_value(100.0); EXPECT_EQ(group_.Get().max_value(), 100.0); } TEST_F(MetricGroupTest, Reset) { // Set up some data group_.GetMutable()->set_count(42); group_.GetMutable()->add_sample(10.0); group_.GetMutable()->set_value(100.0); // Reset and verify everything is back to initial state group_.Reset(); EXPECT_EQ(group_.Get().count(), 0); EXPECT_EQ(group_.Get().count(), 0); EXPECT_FALSE(group_.Get().has_value()); } TEST_F(MetricGroupTest, MergeFromGroup) { // Set up source group TestGroup other; other.GetMutable()->set_count(10); other.GetMutable()->add_sample(5.0); other.GetMutable()->set_value(50.0); // Set up destination group group_.GetMutable()->set_count(20); group_.GetMutable()->add_sample(15.0); group_.GetMutable()->set_value(30.0); // Merge group_.MergeFrom(other); // Verify results EXPECT_EQ(group_.Get().count(), 30); // 20 + 10 EXPECT_EQ(group_.Get().average(), 10.0); // (15 + 5) / 2 EXPECT_EQ(group_.Get().max_value(), 50.0); // max(30, 50) } TEST_F(MetricGroupTest, MergeFromSingleMetric) { // Set up initial state group_.GetMutable()->set_count(20); // Create a single metric to merge CounterMetric counter(15); // Merge single metric group_.MergeFrom(counter); // Verify result EXPECT_EQ(group_.Get().count(), 35); // 20 + 15 } TEST_F(MetricGroupTest, Print) { // Set up data 
group_.GetMutable()->set_count(42); group_.GetMutable()->add_sample(10.0); group_.GetMutable()->add_sample(20.0); group_.GetMutable()->set_value(100.0); std::string result = MetricToString(group_); // Should contain all metric names and values EXPECT_NE(result.find("CounterMetric"), std::string::npos); EXPECT_NE(result.find("count=42"), std::string::npos); EXPECT_NE(result.find("AverageMetric"), std::string::npos); EXPECT_NE(result.find("average=15"), std::string::npos); // (10+20)/2 EXPECT_NE(result.find("MaxMetric"), std::string::npos); EXPECT_NE(result.find("max_value=100"), std::string::npos); } // Test MetricToString functionality class MetricPrinterTest : public ::testing::Test {}; TEST_F(MetricPrinterTest, StringMetricPrinter) { std::string output; StringMetricPrinter printer(&output); printer.StartGroup("test_group"); printer.Print("metric1", std::string("value1")); printer.Print("metric2", std::string("42")); printer.EndGroup(); EXPECT_EQ(output, "test_group={metric1=value1, metric2=42}"); } TEST_F(MetricPrinterTest, MultipleGroups) { std::string output; StringMetricPrinter printer(&output); printer.StartGroup("group1"); printer.Print("metric1", std::string("value1")); printer.EndGroup(); printer.StartGroup("group2"); printer.Print("metric2", std::string("value2")); printer.EndGroup(); EXPECT_EQ(output, "group1={metric1=value1}, group2={metric2=value2}"); } TEST_F(MetricPrinterTest, EmptyGroup) { std::string output; StringMetricPrinter printer(&output); printer.StartGroup("empty_group"); printer.EndGroup(); EXPECT_EQ(output, "empty_group={}"); } TEST_F(MetricPrinterTest, SizeTOverload) { std::string output; StringMetricPrinter string_printer(&output); MetricPrinter& printer = string_printer; // Use base class interface printer.StartGroup("test_group"); printer.Print("count", static_cast(123)); printer.EndGroup(); EXPECT_EQ(output, "test_group={count=123}"); } TEST_F(MetricPrinterTest, MetricToStringFunction) { CounterMetric counter(123); std::string result 
= MetricToString(counter); EXPECT_NE(result.find("CounterMetric"), std::string::npos); EXPECT_NE(result.find("123"), std::string::npos); } class ExponentialAggregatorTest : public ::testing::Test { protected: using TestMetric = MetricGroup; using TestAggregator = ExponentialAggregator; void SetUp() override { // Create a fresh aggregator for each test to avoid state contamination aggregator_ = std::make_unique(); start_time_ = TestAggregator::Clock::now(); aggregator_->Reset(start_time_); } std::unique_ptr aggregator_; TestAggregator::Clock::time_point start_time_; }; TEST_F(ExponentialAggregatorTest, RecordMetrics) { TestMetric metric; metric.GetMutable()->set_count(10); metric.GetMutable()->add_sample(5.0); // Update live metrics aggregator_->RecordMetrics(std::move(metric)); // The original metric should be reset after move EXPECT_EQ(metric.Get().count(), 0); EXPECT_EQ(metric.Get().count(), 0); // Get live metrics to verify they were updated auto [live_metrics, age] = aggregator_->GetAggregateEndingNow(TestAggregator::Duration::zero()); EXPECT_EQ(live_metrics.Get().count(), 10); EXPECT_EQ(live_metrics.Get().average(), 5.0); } TEST_F(ExponentialAggregatorTest, MultipleUpdatesLiveMetrics) { // Update multiple times for (int i = 1; i <= 5; ++i) { TestMetric metric; metric.GetMutable()->set_count(i); metric.GetMutable()->add_sample(i * 2.0); aggregator_->RecordMetrics(std::move(metric)); } // Get live metrics auto [live_metrics, age] = aggregator_->GetAggregateEndingNow(TestAggregator::Duration::zero()); EXPECT_EQ(live_metrics.Get().count(), 15); // 1+2+3+4+5 EXPECT_EQ(live_metrics.Get().average(), 6.0); // (2+4+6+8+10)/5 } TEST_F(ExponentialAggregatorTest, Advance) { // Add some live metrics TestMetric metric; metric.GetMutable()->set_count(10); aggregator_->RecordMetrics(std::move(metric)); // Advance to move live metrics to buckets auto tick_time = start_time_ + aggregator_->GetResolution(); auto period = aggregator_->Advance(tick_time); // Should return the base 
time period EXPECT_EQ(period, TimePeriod::k16Milliseconds); // Live metrics should be empty after tick auto [live_metrics, age] = aggregator_->GetAggregateEndingNow( TestAggregator::Duration::zero(), tick_time); EXPECT_EQ(live_metrics.Get().count(), 0); } TEST_F(ExponentialAggregatorTest, MultipleAdvances) { // Add metrics and tick multiple times to test bucket management auto current_time = start_time_; for (const auto expected_period : { TimePeriod::k16Milliseconds, TimePeriod::k31Milliseconds, TimePeriod::k16Milliseconds, TimePeriod::k63Milliseconds, TimePeriod::k16Milliseconds, TimePeriod::k31Milliseconds, TimePeriod::k16Milliseconds, TimePeriod::k125Milliseconds, }) { TestMetric metric; metric.GetMutable()->set_count(1); aggregator_->RecordMetrics(std::move(metric)); current_time += aggregator_->GetResolution(); auto period = aggregator_->Advance(current_time); EXPECT_EQ(period, expected_period); } } TEST_F(ExponentialAggregatorTest, MultipleAdvancesThreeTicks) { // Add metrics and tick multiple times to test bucket management auto current_time = start_time_; for (const auto expected_period : { TimePeriod::k31Milliseconds, TimePeriod::k63Milliseconds, TimePeriod::k125Milliseconds, TimePeriod::k63Milliseconds, }) { TestMetric metric; metric.GetMutable()->set_count(1); aggregator_->RecordMetrics(std::move(metric)); current_time += aggregator_->GetResolution() * 3; auto period = aggregator_->Advance(current_time); EXPECT_EQ(period, expected_period); } } TEST_F(ExponentialAggregatorTest, AggregationTest) { TestMetric metric; // We do 37 (0b100101) updates. for (int i = 0; i < 37; ++i) { metric.GetMutable()->set_count(i + 200); metric.GetMutable()->set_value(i + 100); aggregator_->RecordMetrics(std::move(metric)); start_time_ += aggregator_->GetResolution(); aggregator_->Advance(start_time_); } // One more tick, but we don't advance this time. 
metric.GetMutable()->set_count(1001); metric.GetMutable()->set_value(1002); aggregator_->RecordMetrics(std::move(metric)); const auto kRes = aggregator_->GetResolution(); start_time_ += kRes / 3; // Not a tick yet. auto check_completed_bucket = [&](TimePeriod period, int expected_count, std::optional expected_value, TestAggregator::Duration expected_duration, bool include_pending = false) { auto [live_metrics, age] = aggregator_->GetBucketMetrics( period, include_pending ? std::make_optional(start_time_) : std::nullopt); EXPECT_EQ(live_metrics.Get().count(), expected_count); EXPECT_EQ(live_metrics.Get().has_value(), expected_value.has_value()); if (expected_value.has_value()) { EXPECT_EQ(live_metrics.Get().value(), expected_value.value()); } EXPECT_EQ(age, expected_duration); }; check_completed_bucket(TimePeriod::k16Milliseconds, 236, 136, kRes * 0); check_completed_bucket(TimePeriod::k31Milliseconds, 234 + 235, 135, kRes); check_completed_bucket(TimePeriod::k63Milliseconds, 232 + 233 + 234 + 235, 135, kRes); check_completed_bucket(TimePeriod::k125Milliseconds, 224 + 225 + 226 + 227 + 228 + 229 + 230 + 231, 131, kRes * 5); check_completed_bucket(TimePeriod::k125Milliseconds, 224 + 225 + 226 + 227 + 228 + 229 + 230 + 231, 131, kRes * 5 + kRes / 3, true); check_completed_bucket(TimePeriod::k250Milliseconds, 216 + 217 + 218 + 219 + 220 + 221 + 222 + 223 + 224 + 225 + 226 + 227 + 228 + 229 + 230 + 231, 131, kRes * 5); check_completed_bucket(TimePeriod::k500Milliseconds, 200 + 201 + 202 + 203 + 204 + 205 + 206 + 207 + 208 + 209 + 210 + 211 + 212 + 213 + 214 + 215 + 216 + 217 + 218 + 219 + 220 + 221 + 222 + 223 + 224 + 225 + 226 + 227 + 228 + 229 + 230 + 231, 131, kRes * 5); check_completed_bucket(TimePeriod::k1Second, 0, std::nullopt, kRes * 37); auto check_aggregate = [&](TestAggregator::Duration duration, int expected_count, std::optional expected_value, TestAggregator::Duration expected_duration, bool include_pending = false) { auto [live_metrics, age] = 
            aggregator_->GetAggregateEndingNow(
                duration,
                include_pending ? std::make_optional(start_time_)
                                : std::nullopt);
        EXPECT_EQ(live_metrics.Get().count(), expected_count);
        EXPECT_EQ(live_metrics.Get().has_value(), expected_value.has_value());
        if (expected_value.has_value()) {
          EXPECT_EQ(live_metrics.Get().value(), expected_value.value());
        }
        EXPECT_EQ(age, expected_duration);
      };
  check_aggregate(TestAggregator::Duration::zero(), 0, std::nullopt, kRes * 0);
  check_aggregate(TestAggregator::Duration::zero(), 1001, 1002, kRes / 3, true);
  check_aggregate(kRes / 4, 1001, 1002, kRes / 3, true);
  check_aggregate(kRes / 10, 236, 136, kRes);
  check_aggregate(kRes * 45 / 10, 232 + 233 + 234 + 235 + 236, 136, kRes * 5);
  check_aggregate(kRes * 55 / 10,
                  224 + 225 + 226 + 227 + 228 + 229 + 230 + 231 + 232 + 233 +
                      234 + 235 + 236,
                  136, kRes * (5 + 8));
}

TEST_F(ExponentialAggregatorTest, ActualVsRequestedTimeCoverage) {
  // Test that GetAggregateEndingNow returns actual time covered by statistics
  // rather than requested duration when insufficient historical data exists.
  // This test recreates the scenario from the existing AggregationTest but
  // specifically tests the requested vs actual duration behavior.
  const auto kRes = aggregator_->GetResolution();

  // Set up aggregator with several data points like in AggregationTest
  TestMetric metric;
  for (int i = 0; i < 10; ++i) {
    metric.GetMutable()->set_count(i + 200);
    aggregator_->RecordMetrics(std::move(metric));
    start_time_ += kRes;
    aggregator_->Advance(start_time_);
  }

  // Based on AggregationTest line 507: request 4.5 * kRes, get back 5 * kRes
  // This demonstrates that actual coverage (5 * kRes) can be MORE than
  // requested (4.5 * kRes) because the aggregator only has specific bucket
  // sizes available
  const auto requested_duration = kRes * 45 / 10;  // 4.5 * kRes
  auto [result_metrics, actual_duration] =
      aggregator_->GetAggregateEndingNow(requested_duration, std::nullopt);

  // The key test: when requesting 4.5 * kRes, we should get actual time covered
  // which may be different than the requested amount due to bucket granularity
  EXPECT_GT(actual_duration, requested_duration);  // Actual > requested
  EXPECT_GT(actual_duration, std::chrono::nanoseconds::zero());

  // Verify we got some metrics (non-zero count)
  EXPECT_GT(result_metrics.Get().count(), 0);

  // Test that shows the key behavior: when we request more time than available,
  // we get back only the time that's actually covered by data
  auto [result_zero, duration_zero] = aggregator_->GetAggregateEndingNow(
      kRes * 100, std::nullopt);  // Request way more

  // The returned duration should be much less than requested (showing actual
  // vs requested)
  const auto huge_request = kRes * 100;
  EXPECT_LT(duration_zero, huge_request);
  EXPECT_GT(duration_zero, std::chrono::nanoseconds::zero());
}

TEST_F(ExponentialAggregatorTest, ExactDurationTest) {
  // Simple test: add buckets for exactly 5 seconds, request kAllTime,
  // ensure we get back exactly 5.0 seconds duration (not more)
  auto current_time = start_time_;

  // Add buckets for exactly 5 seconds
  for (int i = 0; i < 5; ++i) {
    TestMetric metric;
    metric.GetMutable()->set_count(100 + i);
    aggregator_->RecordMetrics(std::move(metric));
current_time += std::chrono::seconds(1); aggregator_->Advance(current_time); } // Request statistics for all time auto [result_metrics, actual_duration] = aggregator_->GetAggregateEndingNow( std::chrono::duration_cast( std::chrono::hours(24 * 365)), // Request way more than 5 seconds std::nullopt); // Should return exactly 5.0 seconds duration (actual time covered) const auto expected_duration = std::chrono::seconds(5); EXPECT_EQ(actual_duration, expected_duration); // Should have all our data const int expected_total = 100 + 101 + 102 + 103 + 104; EXPECT_EQ(result_metrics.Get().count(), expected_total); } } // namespace lczero int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } ================================================ FILE: csrc/utils/queue.h ================================================ #pragma once #include #include #include #include #include "absl/base/thread_annotations.h" #include "absl/container/fixed_array.h" #include "absl/log/log.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" namespace lczero { // Virtual base class for type-erased handling of queues. class QueueBase { public: virtual ~QueueBase() = default; virtual size_t Size() const = 0; virtual size_t Capacity() const = 0; virtual bool IsClosed() const = 0; virtual void Close() = 0; virtual size_t GetTotalPutCount(bool reset = false) = 0; virtual size_t GetTotalGetCount(bool reset = false) = 0; virtual size_t GetTotalDropCount(bool reset = false) = 0; }; // Exception thrown when queue operations are attempted on a closed queue. class QueueClosedException : public std::runtime_error { public: QueueClosedException() : std::runtime_error("Queue is closed") {} }; // Exception thrown when queue operation is cancelled via stop_token. 
class QueueRequestCancelled : public std::runtime_error { public: QueueRequestCancelled() : std::runtime_error("Queue request cancelled") {} }; enum class OverflowBehavior { BLOCK, DROP_NEW, KEEP_NEWEST }; // Thread-safe fixed-size circular buffer queue with blocking operations. // Supports both single and batch put/get operations. // The queue automatically closes when all Producer tokens are destroyed. // When closed, Put operations throw immediately, but Get operations only throw // when the queue becomes empty - allowing consumption of remaining elements. template class Queue : public QueueBase { public: // Backwards-compatible alias to support code referring to // Queue::OverflowBehavior. using OverflowBehavior = ::lczero::OverflowBehavior; explicit Queue(size_t capacity, OverflowBehavior overflow_behavior = OverflowBehavior::BLOCK); // RAII token for producers. Queue automatically closes when all producers // are destroyed. All Put operations must go through this class. class Producer { public: explicit Producer(Queue& queue); ~Producer(); // Move constructor and assignment Producer(Producer&& other) noexcept; Producer& operator=(Producer&& other) noexcept; // Disable copy to maintain RAII semantics Producer(const Producer&) = delete; Producer& operator=(const Producer&) = delete; // Puts a single element into the queue. Blocks if queue is full. void Put(const T& item, std::stop_token stop_token = {}); void Put(T&& item, std::stop_token stop_token = {}); // Puts multiple elements into the queue. Blocks if not enough space. void Put(absl::Span items, std::stop_token stop_token = {}); void Put(absl::Span items, std::stop_token stop_token = {}); // Explicitly close this producer, decrementing the producer count void Close(); private: Queue* queue_; }; // Creates a new producer token for this queue. Producer CreateProducer(); // Gets a single element from the queue. Blocks if queue is empty. 
T Get(std::stop_token stop_token = {}); // Gets exactly count elements from the queue. Blocks until count elements // available. absl::FixedArray Get(size_t count, std::stop_token stop_token = {}); // Gets a single element from the queue if available, returns std::nullopt // if empty. std::optional MaybeGet(); // Returns the current size of the queue. size_t Size() const override; // Returns the capacity of the queue. size_t Capacity() const override; // Explicitly close the queue, preventing further Put operations. void Close() override; // Returns true if the queue is closed. bool IsClosed() const override; // Wait until queue has at least the specified amount of free space. void WaitForRoomAtLeast(size_t room, std::stop_token stop_token = {}); // Wait until queue has at most the specified amount of free space. void WaitForRoomAtMost(size_t room, std::stop_token stop_token = {}); // Wait until queue has at least the specified number of elements. void WaitForSizeAtLeast(size_t size, std::stop_token stop_token = {}); // Wait until queue has at most the specified number of elements. void WaitForSizeAtMost(size_t size, std::stop_token stop_token = {}); // Returns the total number of elements that have been put into the queue. // If reset is true, resets the counter to 0 after returning the value. size_t GetTotalPutCount(bool reset = false) override; // Returns the total number of elements that have been retrieved from the // queue. If reset is true, resets the counter to 0 after returning the value. size_t GetTotalGetCount(bool reset = false) override; // Returns the total number of elements that have been dropped from the queue. // If reset is true, resets the counter to 0 after returning the value. 
size_t GetTotalDropCount(bool reset = false) override; private: friend class Producer; const size_t capacity_; const OverflowBehavior overflow_behavior_; absl::FixedArray buffer_ ABSL_GUARDED_BY(mutex_); size_t head_ ABSL_GUARDED_BY(mutex_) = 0; size_t tail_ ABSL_GUARDED_BY(mutex_) = 0; size_t size_ ABSL_GUARDED_BY(mutex_) = 0; size_t producer_count_ ABSL_GUARDED_BY(mutex_) = 0; bool closed_ ABSL_GUARDED_BY(mutex_) = false; size_t total_put_count_ ABSL_GUARDED_BY(mutex_) = 0; size_t total_get_count_ ABSL_GUARDED_BY(mutex_) = 0; size_t total_drop_count_ ABSL_GUARDED_BY(mutex_) = 0; mutable absl::Mutex mutex_; absl::CondVar cond_var_; // Internal methods for producer management void RemoveProducer(); // Internal Put methods (called by Producer) void PutInternal(const T& item, std::stop_token stop_token = {}); void PutInternal(T&& item, std::stop_token stop_token = {}); void PutInternal(absl::Span items, std::stop_token stop_token = {}); void PutInternal(absl::Span items, std::stop_token stop_token = {}); // Condition predicates for blocking operations bool CanPutOne() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); bool CanGet() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Additional condition predicates for wait functions bool HasRoomAtLeast(size_t room) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); bool HasRoomAtMost(size_t room) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); bool HasSizeAtLeast(size_t size) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); bool HasSizeAtMost(size_t size) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); }; // Implementation template Queue::Queue(size_t capacity, OverflowBehavior overflow_behavior) : capacity_(capacity), overflow_behavior_(overflow_behavior), buffer_(capacity) {} // Producer implementation template Queue::Producer::Producer(Queue& queue) : queue_(&queue) { // Producer count is incremented in CreateProducer() VLOG(1) << "Queue@" << static_cast(queue_) << " producer@" << static_cast(this) << " constructed."; } template Queue::Producer::~Producer() { if (queue_) { 
VLOG(1) << "Queue@" << static_cast(queue_) << " producer@" << static_cast(this) << " destructing."; queue_->RemoveProducer(); } } template Queue::Producer::Producer(Producer&& other) noexcept : queue_(other.queue_) { other.queue_ = nullptr; } template typename Queue::Producer& Queue::Producer::operator=( Producer&& other) noexcept { if (this != &other) { if (queue_) { queue_->RemoveProducer(); } queue_ = other.queue_; other.queue_ = nullptr; } return *this; } template void Queue::Producer::Put(const T& item, std::stop_token stop_token) { queue_->PutInternal(item, stop_token); } template void Queue::Producer::Put(T&& item, std::stop_token stop_token) { queue_->PutInternal(std::move(item), stop_token); } template void Queue::Producer::Put(absl::Span items, std::stop_token stop_token) { queue_->PutInternal(items, stop_token); } template void Queue::Producer::Put(absl::Span items, std::stop_token stop_token) { queue_->PutInternal(items, stop_token); } template void Queue::Producer::Close() { if (queue_) { VLOG(1) << "Queue@" << static_cast(queue_) << " producer@" << static_cast(this) << " close invoked."; queue_->RemoveProducer(); queue_ = nullptr; } } // Queue implementation template typename Queue::Producer Queue::CreateProducer() { absl::MutexLock lock(&mutex_); if (closed_) throw QueueClosedException(); ++producer_count_; return Producer(*this); } template void Queue::RemoveProducer() { absl::MutexLock lock(&mutex_); --producer_count_; if (producer_count_ == 0 && !closed_) { closed_ = true; VLOG(1) << "Queue@" << static_cast(this) << " closed after last producer removed."; cond_var_.SignalAll(); } } template void Queue::PutInternal(const T& item, std::stop_token stop_token) { absl::MutexLock lock(&mutex_); if (closed_) { VLOG(1) << "Queue@" << static_cast(this) << " PutInternal(const&) throwing QueueClosedException;" << " producers=" << producer_count_; throw QueueClosedException(); } ++total_put_count_; switch (overflow_behavior_) { case OverflowBehavior::BLOCK: { 
std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!CanPutOne()) { if (closed_) throw QueueClosedException(); if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } if (closed_) throw QueueClosedException(); break; } case OverflowBehavior::DROP_NEW: if (size_ >= capacity_) { ++total_drop_count_; return; } break; case OverflowBehavior::KEEP_NEWEST: if (size_ >= capacity_) { head_ = (head_ + 1) % capacity_; --size_; ++total_drop_count_; } break; } buffer_[tail_] = item; tail_ = (tail_ + 1) % capacity_; ++size_; cond_var_.SignalAll(); } template void Queue::PutInternal(T&& item, std::stop_token stop_token) { absl::MutexLock lock(&mutex_); if (closed_) { VLOG(1) << "Queue@" << static_cast(this) << " PutInternal(T&&) throwing QueueClosedException;" << " producers=" << producer_count_; throw QueueClosedException(); } ++total_put_count_; switch (overflow_behavior_) { case OverflowBehavior::BLOCK: { std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!CanPutOne()) { if (closed_) throw QueueClosedException(); if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } if (closed_) throw QueueClosedException(); break; } case OverflowBehavior::DROP_NEW: if (size_ >= capacity_) { ++total_drop_count_; return; } break; case OverflowBehavior::KEEP_NEWEST: if (size_ >= capacity_) { head_ = (head_ + 1) % capacity_; --size_; ++total_drop_count_; } break; } buffer_[tail_] = std::move(item); tail_ = (tail_ + 1) % capacity_; ++size_; cond_var_.SignalAll(); } template void Queue::PutInternal(absl::Span items, std::stop_token stop_token) { if (items.empty()) return; size_t remaining = items.size(); size_t offset = 0; while (remaining > 0) { absl::MutexLock lock(&mutex_); if (closed_) { VLOG(1) << "Queue@" << static_cast(this) << " PutInternal(span const) throwing QueueClosedException;" << " producers=" << producer_count_; throw QueueClosedException(); } size_t 
batch_size; switch (overflow_behavior_) { case OverflowBehavior::BLOCK: { std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!CanPutOne()) { if (closed_) throw QueueClosedException(); if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } if (closed_) throw QueueClosedException(); batch_size = std::min(remaining, capacity_ - size_); break; } case OverflowBehavior::DROP_NEW: batch_size = std::min(remaining, capacity_ - size_); if (batch_size == 0) { total_put_count_ += remaining; total_drop_count_ += remaining; return; } break; case OverflowBehavior::KEEP_NEWEST: batch_size = std::min(remaining, capacity_); while (size_ + batch_size > capacity_) { head_ = (head_ + 1) % capacity_; --size_; ++total_drop_count_; } break; } for (size_t i = 0; i < batch_size; ++i) { buffer_[tail_] = items[offset + i]; tail_ = (tail_ + 1) % capacity_; ++size_; } total_put_count_ += batch_size; cond_var_.SignalAll(); offset += batch_size; remaining -= batch_size; } } template void Queue::PutInternal(absl::Span items, std::stop_token stop_token) { if (items.empty()) return; size_t remaining = items.size(); size_t offset = 0; while (remaining > 0) { absl::MutexLock lock(&mutex_); if (closed_) { VLOG(1) << "Queue@" << static_cast(this) << " PutInternal(span) throwing QueueClosedException;" << " producers=" << producer_count_; throw QueueClosedException(); } size_t batch_size; switch (overflow_behavior_) { case OverflowBehavior::BLOCK: { std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!CanPutOne()) { if (closed_) throw QueueClosedException(); if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } if (closed_) throw QueueClosedException(); batch_size = std::min(remaining, capacity_ - size_); break; } case OverflowBehavior::DROP_NEW: batch_size = std::min(remaining, capacity_ - size_); if (batch_size == 0) { total_put_count_ += remaining; total_drop_count_ += 
remaining; return; } break; case OverflowBehavior::KEEP_NEWEST: batch_size = std::min(remaining, capacity_); while (size_ + batch_size > capacity_) { head_ = (head_ + 1) % capacity_; --size_; ++total_drop_count_; } break; } for (size_t i = 0; i < batch_size; ++i) { buffer_[tail_] = std::move(items[offset + i]); tail_ = (tail_ + 1) % capacity_; ++size_; } total_put_count_ += batch_size; cond_var_.SignalAll(); offset += batch_size; remaining -= batch_size; } } template T Queue::Get(std::stop_token stop_token) { absl::MutexLock lock(&mutex_); std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!CanGet()) { if (closed_ && size_ == 0) { VLOG(1) << "Queue@" << static_cast(this) << " Get() throwing QueueClosedException; producers=" << producer_count_; throw QueueClosedException(); } if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } if (closed_ && size_ == 0) { VLOG(1) << "Queue@" << static_cast(this) << " Get() throwing QueueClosedException; producers=" << producer_count_; throw QueueClosedException(); } T item = std::move(buffer_[head_]); head_ = (head_ + 1) % capacity_; --size_; ++total_get_count_; cond_var_.SignalAll(); return item; } template absl::FixedArray Queue::Get(size_t count, std::stop_token stop_token) { if (count == 0) return absl::FixedArray(0); absl::FixedArray result(count); size_t remaining = count; size_t offset = 0; while (remaining > 0) { absl::MutexLock lock(&mutex_); std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!CanGet()) { if (closed_ && size_ == 0) { VLOG(1) << "Queue@" << static_cast(this) << " Get(" << count << ") throwing QueueClosedException; producers=" << producer_count_; throw QueueClosedException(); } if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } if (closed_ && size_ == 0) { VLOG(1) << "Queue@" << static_cast(this) << " Get(" << count << ") throwing QueueClosedException; producers=" << 
producer_count_; throw QueueClosedException(); } size_t batch_size = std::min(remaining, size_); for (size_t i = 0; i < batch_size; ++i) { result[offset + i] = std::move(buffer_[head_]); head_ = (head_ + 1) % capacity_; --size_; ++total_get_count_; } cond_var_.SignalAll(); offset += batch_size; remaining -= batch_size; } return result; } template std::optional Queue::MaybeGet() { absl::MutexLock lock(&mutex_); if (size_ == 0) return std::nullopt; T item = std::move(buffer_[head_]); head_ = (head_ + 1) % capacity_; --size_; ++total_get_count_; cond_var_.SignalAll(); return item; } template size_t Queue::Size() const { absl::MutexLock lock(&mutex_); return size_; } template size_t Queue::Capacity() const { return capacity_; } template void Queue::Close() { absl::MutexLock lock(&mutex_); if (!closed_) { closed_ = true; VLOG(1) << "Queue@" << static_cast(this) << " closed explicitly; producers=" << producer_count_; cond_var_.SignalAll(); } } template bool Queue::IsClosed() const { absl::MutexLock lock(&mutex_); return closed_; } template void Queue::WaitForRoomAtLeast(size_t room, std::stop_token stop_token) { absl::MutexLock lock(&mutex_); std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!HasRoomAtLeast(room)) { if (closed_) throw QueueClosedException(); if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } } template void Queue::WaitForRoomAtMost(size_t room, std::stop_token stop_token) { absl::MutexLock lock(&mutex_); std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!HasRoomAtMost(room)) { if (closed_) throw QueueClosedException(); if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } } template void Queue::WaitForSizeAtLeast(size_t size, std::stop_token stop_token) { absl::MutexLock lock(&mutex_); std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!HasSizeAtLeast(size)) { if (closed_) throw 
QueueClosedException(); if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } } template void Queue::WaitForSizeAtMost(size_t size, std::stop_token stop_token) { absl::MutexLock lock(&mutex_); std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); }); while (!HasSizeAtMost(size)) { if (closed_) throw QueueClosedException(); if (stop_token.stop_requested()) throw QueueRequestCancelled(); cond_var_.Wait(&mutex_); } } template bool Queue::CanPutOne() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { return closed_ || size_ < capacity_; } template bool Queue::CanGet() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { return closed_ || size_ > 0; } template bool Queue::HasRoomAtLeast(size_t room) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { return capacity_ - size_ >= room; } template bool Queue::HasRoomAtMost(size_t room) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { return capacity_ - size_ <= room; } template bool Queue::HasSizeAtLeast(size_t size) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { return size_ >= size; } template bool Queue::HasSizeAtMost(size_t size) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { return size_ <= size; } template size_t Queue::GetTotalPutCount(bool reset) { absl::MutexLock lock(&mutex_); size_t count = total_put_count_; if (reset) total_put_count_ = 0; return count; } template size_t Queue::GetTotalGetCount(bool reset) { absl::MutexLock lock(&mutex_); size_t count = total_get_count_; if (reset) total_get_count_ = 0; return count; } template size_t Queue::GetTotalDropCount(bool reset) { absl::MutexLock lock(&mutex_); size_t count = total_drop_count_; if (reset) total_drop_count_ = 0; return count; } } // namespace lczero ================================================ FILE: csrc/utils/queue_test.cc ================================================ // ABOUTME: Comprehensive unit tests for the Queue template class // ABOUTME: Tests thread-safe operations, blocking behavior, and edge cases #include "utils/queue.h" #include #include 
#include #include #include #include namespace lczero { class QueueTest : public ::testing::Test { protected: void SetUp() override {} }; // Basic functionality tests TEST_F(QueueTest, ConstructorCreatesEmptyQueue) { Queue queue(5); EXPECT_EQ(queue.Size(), 0); EXPECT_EQ(queue.Capacity(), 5); } TEST_F(QueueTest, SinglePutGet) { Queue queue(5); { auto producer = queue.CreateProducer(); producer.Put(42); EXPECT_EQ(queue.Size(), 1); } // Producer destroyed here, queue closes int value = queue.Get(); EXPECT_EQ(value, 42); EXPECT_EQ(queue.Size(), 0); } TEST_F(QueueTest, MovePutGet) { Queue> queue(5); { auto producer = queue.CreateProducer(); auto ptr = std::make_unique(42); producer.Put(std::move(ptr)); EXPECT_EQ(queue.Size(), 1); } // Producer destroyed here, queue closes auto result = queue.Get(); EXPECT_EQ(*result, 42); EXPECT_EQ(queue.Size(), 0); } TEST_F(QueueTest, MultiplePutGet) { Queue queue(5); { auto producer = queue.CreateProducer(); for (int i = 0; i < 5; ++i) { producer.Put(i); } EXPECT_EQ(queue.Size(), 5); } // Producer destroyed here, queue closes for (int i = 0; i < 5; ++i) { int value = queue.Get(); EXPECT_EQ(value, i); } EXPECT_EQ(queue.Size(), 0); } TEST_F(QueueTest, CircularBufferBehavior) { Queue queue(3); auto producer = queue.CreateProducer(); // Fill queue producer.Put(1); producer.Put(2); producer.Put(3); // Get one item, put another EXPECT_EQ(queue.Get(), 1); producer.Put(4); // Verify remaining items EXPECT_EQ(queue.Get(), 2); EXPECT_EQ(queue.Get(), 3); EXPECT_EQ(queue.Get(), 4); } // Batch operations tests TEST_F(QueueTest, BatchPutConstSpan) { Queue queue(5); std::vector items = {1, 2, 3}; { auto producer = queue.CreateProducer(); producer.Put(absl::Span(items)); EXPECT_EQ(queue.Size(), 3); } // Producer destroyed here, queue closes for (int i = 0; i < 3; ++i) { EXPECT_EQ(queue.Get(), i + 1); } } TEST_F(QueueTest, BatchPutMoveSpan) { Queue> queue(5); std::vector> items; items.push_back(std::make_unique(1)); 
items.push_back(std::make_unique(2)); items.push_back(std::make_unique(3)); { auto producer = queue.CreateProducer(); producer.Put(absl::Span>(items)); EXPECT_EQ(queue.Size(), 3); } // Producer destroyed here, queue closes for (int i = 0; i < 3; ++i) { auto result = queue.Get(); EXPECT_EQ(*result, i + 1); } } TEST_F(QueueTest, BatchPutEmptySpan) { Queue queue(5); std::vector empty_items; { auto producer = queue.CreateProducer(); producer.Put(absl::Span(empty_items)); EXPECT_EQ(queue.Size(), 0); } // Producer destroyed here, queue closes } TEST_F(QueueTest, BatchGet) { Queue queue(5); { auto producer = queue.CreateProducer(); for (int i = 0; i < 5; ++i) { producer.Put(i); } } // Producer destroyed here, queue closes auto result = queue.Get(3); EXPECT_EQ(result.size(), 3); for (int i = 0; i < 3; ++i) { EXPECT_EQ(result[i], i); } EXPECT_EQ(queue.Size(), 2); } TEST_F(QueueTest, BatchGetZeroCount) { Queue queue(5); { auto producer = queue.CreateProducer(); producer.Put(42); } // Producer destroyed here, queue closes auto result = queue.Get(0); EXPECT_EQ(result.size(), 0); EXPECT_EQ(queue.Size(), 1); } // Edge cases and error conditions TEST_F(QueueTest, CapacityOne) { Queue queue(1); { auto producer = queue.CreateProducer(); producer.Put(42); EXPECT_EQ(queue.Size(), 1); } // Producer destroyed here, queue closes EXPECT_EQ(queue.Get(), 42); EXPECT_EQ(queue.Size(), 0); } // Tests for operations when all producer tokens are destroyed TEST_F(QueueTest, CreateProducerOnClosedQueue) { Queue queue(5); // Create and immediately destroy producer to close queue { auto producer = queue.CreateProducer(); } // Trying to create a new producer after queue is closed results in an // exception. 
EXPECT_THROW(queue.CreateProducer(), QueueClosedException); } TEST_F(QueueTest, GetOnClosedQueue) { Queue queue(5); // Create and immediately destroy producer to close queue { auto producer = queue.CreateProducer(); } EXPECT_THROW(queue.Get(), QueueClosedException); } TEST_F(QueueTest, BatchGetOnClosedQueue) { Queue queue(5); // Create and immediately destroy producer to close queue { auto producer = queue.CreateProducer(); } EXPECT_THROW(queue.Get(3), QueueClosedException); } // Thread safety tests TEST_F(QueueTest, SingleProducerSingleConsumer) { Queue queue(10); std::atomic producer_done{false}; std::vector consumed; std::thread producer([&queue, &producer_done]() { auto prod = queue.CreateProducer(); for (int i = 0; i < 100; ++i) { prod.Put(i); } producer_done = true; // Producer destroyed here, closing the queue }); std::thread consumer([&queue, &consumed, &producer_done]() { int value; while (!producer_done || queue.Size() > 0) { try { value = queue.Get(); consumed.push_back(value); } catch (const QueueClosedException&) { break; } } }); producer.join(); consumer.join(); EXPECT_EQ(consumed.size(), 100); for (int i = 0; i < 100; ++i) { EXPECT_EQ(consumed[i], i); } } TEST_F(QueueTest, MultipleProducersMultipleConsumers) { Queue queue(10); constexpr int num_producers = 2; constexpr int items_per_producer = 5; constexpr int total_items = num_producers * items_per_producer; std::vector all_consumed; std::vector producers; // Use a single producer token that we control explicitly auto producer_token = queue.CreateProducer(); // Start producers - they all share the same producer token via reference for (int p = 0; p < num_producers; ++p) { producers.emplace_back([&producer_token, p]() { for (int i = 0; i < items_per_producer; ++i) { int value = p * items_per_producer + i; producer_token.Put(value); } }); } // Wait for all producers to finish for (auto& producer : producers) { producer.join(); } // Now explicitly close the queue by destroying the producer token { auto 
temp = std::move(producer_token); } // Queue is now closed // Now consume all items from the closed queue for (int i = 0; i < total_items; ++i) { all_consumed.push_back(queue.Get()); } // Verify all items were consumed EXPECT_EQ(all_consumed.size(), total_items); EXPECT_EQ(queue.Size(), 0); // Trying to get one more should throw EXPECT_THROW(queue.Get(), QueueClosedException); } TEST_F(QueueTest, BlockingBehaviorOnFullQueue) { Queue queue(2); std::promise about_to_block; std::future about_to_block_future = about_to_block.get_future(); std::atomic put_completed{false}; auto producer = queue.CreateProducer(); // Fill the queue producer.Put(1); producer.Put(2); std::thread blocker([&producer, &about_to_block, &put_completed]() { about_to_block.set_value(); // Signal we're about to block producer.Put(3); // This should block put_completed = true; }); // Wait for thread to signal it's about to block about_to_block_future.wait(); EXPECT_FALSE(put_completed); // Make space in the queue EXPECT_EQ(queue.Get(), 1); blocker.join(); EXPECT_TRUE(put_completed); EXPECT_EQ(queue.Size(), 2); } TEST_F(QueueTest, BlockingBehaviorOnEmptyQueue) { Queue queue(5); std::promise about_to_block; std::future about_to_block_future = about_to_block.get_future(); std::atomic get_completed{false}; std::atomic result{-1}; auto producer = queue.CreateProducer(); std::thread blocker([&queue, &about_to_block, &get_completed, &result]() { about_to_block.set_value(); // Signal we're about to block result = queue.Get(); // This should block get_completed = true; }); // Wait for thread to signal it's about to block about_to_block_future.wait(); EXPECT_FALSE(get_completed); // Put an item in the queue producer.Put(42); blocker.join(); EXPECT_TRUE(get_completed); EXPECT_EQ(result, 42); } TEST_F(QueueTest, ProducerDestructionUnblocksWaitingGet) { Queue queue(5); // Empty queue std::promise about_to_block; std::future about_to_block_future = about_to_block.get_future(); std::atomic exception_thrown{false}; 
// Create a producer to keep queue open initially std::unique_ptr::Producer> producer = std::make_unique::Producer>(queue.CreateProducer()); std::thread blocker([&queue, &about_to_block, &exception_thrown]() { about_to_block.set_value(); // Signal we're about to block try { queue.Get(); // This should block } catch (const QueueClosedException&) { exception_thrown = true; } }); // Wait for thread to signal it's about to block about_to_block_future.wait(); EXPECT_FALSE(exception_thrown); // Destroy the producer - this should close queue and unblock the waiting // Get() producer.reset(); blocker.join(); EXPECT_TRUE(exception_thrown); } // Test: Get() should not throw when queue is closed but has elements TEST_F(QueueTest, GetFromClosedQueueWithElements) { Queue queue(5); // Put some elements in the queue, then destroy producer to close it { auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); producer.Put(3); EXPECT_EQ(queue.Size(), 3); } // Producer destroyed here, queue closes // Should be able to get elements that were already in the queue EXPECT_EQ(queue.Get(), 1); EXPECT_EQ(queue.Get(), 2); EXPECT_EQ(queue.Get(), 3); EXPECT_EQ(queue.Size(), 0); // Only now should Get() throw when queue is empty and closed EXPECT_THROW(queue.Get(), QueueClosedException); } TEST_F(QueueTest, BatchGetFromClosedQueueWithElements) { Queue queue(5); // Put some elements in the queue, then destroy producer to close it { auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); producer.Put(3); EXPECT_EQ(queue.Size(), 3); } // Producer destroyed here, queue closes // Should be able to get elements that were already in the queue auto result = queue.Get(2); EXPECT_EQ(result.size(), 2); EXPECT_EQ(result[0], 1); EXPECT_EQ(result[1], 2); EXPECT_EQ(queue.Size(), 1); // Get remaining element EXPECT_EQ(queue.Get(), 3); EXPECT_EQ(queue.Size(), 0); // Only now should Get() throw when queue is empty and closed EXPECT_THROW(queue.Get(1), QueueClosedException); } 
// Test producer token mechanism specifically.
TEST_F(QueueTest, ProducerTokenMechanism) {
  Queue<int> queue(5);

  // Create multiple producers.
  auto producer1 = queue.CreateProducer();
  auto producer2 = queue.CreateProducer();

  // Both should be able to put items.
  producer1.Put(1);
  producer2.Put(2);
  EXPECT_EQ(queue.Size(), 2);

  // Destroy one producer - queue should still be open.
  {
    auto temp = std::move(producer1);
  }  // producer1 is destroyed here.
  producer2.Put(3);
  EXPECT_EQ(queue.Size(), 3);

  // Destroy last producer - queue should close.
  {
    auto temp = std::move(producer2);
  }  // producer2 is destroyed here.

  // Should still be able to get existing items.
  EXPECT_EQ(queue.Get(), 1);
  EXPECT_EQ(queue.Get(), 2);
  EXPECT_EQ(queue.Get(), 3);

  // But trying to get more should throw.
  EXPECT_THROW(queue.Get(), QueueClosedException);
}

TEST_F(QueueTest, ProducerMoveSemantics) {
  Queue<int> queue(5);

  auto producer1 = queue.CreateProducer();
  producer1.Put(42);

  // Move constructor.
  auto producer2 = std::move(producer1);
  producer2.Put(43);
  EXPECT_EQ(queue.Size(), 2);

  // Create another producer and use move assignment.
  auto producer3 = queue.CreateProducer();
  producer3 = std::move(producer2);
  producer3.Put(44);
  EXPECT_EQ(queue.Size(), 3);

  // Destroy the last producer.
  {
    auto temp = std::move(producer3);
  }  // producer3 is destroyed here.

  // Should be able to get all items.
  EXPECT_EQ(queue.Get(), 42);
  EXPECT_EQ(queue.Get(), 43);
  EXPECT_EQ(queue.Get(), 44);
  EXPECT_THROW(queue.Get(), QueueClosedException);
}

// Tests for Put operations on closed queue.
TEST_F(QueueTest, PutOnClosedQueueThrowsException) {
  Queue<int> queue(5);

  // Create producer and close it.
  auto producer = queue.CreateProducer();
  queue.Close();

  // All Put operations should throw on closed queue.
  EXPECT_THROW(producer.Put(42), QueueClosedException);
  EXPECT_THROW(producer.Put(std::move(42)), QueueClosedException);
  std::vector<int> items = {1, 2, 3};
  EXPECT_THROW(producer.Put(absl::Span<const int>(items)),
               QueueClosedException);
  EXPECT_THROW(producer.Put(absl::Span<int>(items)), QueueClosedException);
}

TEST_F(QueueTest, PutOnClosedQueueAfterProducerDestruction) {
  Queue<int> queue(5);

  // Create producer, add item, then close by destroying all producers.
  auto producer = queue.CreateProducer();
  producer.Put(1);
  {
    auto temp_producer = std::move(producer);
  }  // All producers destroyed, queue closed.

  // Try to create new producer after close.
  EXPECT_THROW(queue.CreateProducer(), QueueClosedException);
}

TEST_F(QueueTest, BatchPutOnClosedQueueThrowsException) {
  Queue<int> queue(10);
  auto producer = queue.CreateProducer();
  queue.Close();

  // Batch put operations should throw on closed queue.
  std::vector<int> items = {1, 2, 3, 4, 5};
  EXPECT_THROW(producer.Put(absl::Span<const int>(items)),
               QueueClosedException);
  std::vector<int> mutable_items = {6, 7, 8};
  EXPECT_THROW(producer.Put(absl::Span<int>(mutable_items)),
               QueueClosedException);
}

TEST_F(QueueTest, PublicCloseMethod) {
  Queue<int> queue(5);
  auto producer = queue.CreateProducer();
  producer.Put(1);
  producer.Put(2);

  // Explicitly close the queue using public Close() method.
  queue.Close();

  // Put operations should now throw.
  EXPECT_THROW(producer.Put(3), QueueClosedException);

  // But Get operations should still work for existing items.
  EXPECT_EQ(queue.Get(), 1);
  EXPECT_EQ(queue.Get(), 2);

  // Get should throw when queue is empty and closed.
  EXPECT_THROW(queue.Get(), QueueClosedException);
}

TEST_F(QueueTest, CloseUnblocksWaitingSinglePut) {
  Queue<int> queue(2);  // Small capacity.
  auto producer = queue.CreateProducer();

  // Fill the queue.
  producer.Put(1);
  producer.Put(2);

  std::promise<void> about_to_block;
  std::future<void> about_to_block_future = about_to_block.get_future();
  std::atomic<bool> exception_thrown{false};

  std::thread blocker([&producer, &about_to_block, &exception_thrown]() {
    about_to_block.set_value();  // Signal we're about to block.
    try {
      producer.Put(3);  // This should block since queue is full.
    } catch (const QueueClosedException&) {
      exception_thrown = true;
    }
  });

  // Wait for thread to signal it's about to block.
  about_to_block_future.wait();
EXPECT_FALSE(exception_thrown); // Close the queue - this should unblock the waiting Put() queue.Close(); blocker.join(); EXPECT_TRUE(exception_thrown); } TEST_F(QueueTest, CloseUnblocksWaitingBatchPut) { Queue queue(3); // Small capacity auto producer = queue.CreateProducer(); // Fill the queue partially producer.Put(1); producer.Put(2); std::promise about_to_block; std::future about_to_block_future = about_to_block.get_future(); std::atomic exception_thrown{false}; std::thread blocker([&producer, &about_to_block, &exception_thrown]() { about_to_block.set_value(); // Signal we're about to block try { std::vector items = {3, 4, 5}; // Need 3 slots but only 1 available producer.Put(absl::Span(items)); // This should block } catch (const QueueClosedException&) { exception_thrown = true; } }); // Wait for thread to signal it's about to block about_to_block_future.wait(); EXPECT_FALSE(exception_thrown); // Close the queue - this should unblock the waiting batch Put() queue.Close(); blocker.join(); EXPECT_TRUE(exception_thrown); } // Tests for new wait functions TEST_F(QueueTest, WaitForRoomAtLeast) { Queue queue(5); auto producer = queue.CreateProducer(); // Initially empty queue should have room >= 5 queue.WaitForRoomAtLeast(5); EXPECT_EQ(queue.Size(), 0); // Fill queue partially producer.Put(1); producer.Put(2); // Should have room >= 3 queue.WaitForRoomAtLeast(3); EXPECT_EQ(queue.Size(), 2); // Fill more producer.Put(3); producer.Put(4); // Should have room >= 1 queue.WaitForRoomAtLeast(1); EXPECT_EQ(queue.Size(), 4); // Test blocking behavior producer.Put(5); // Queue is now full std::promise wait_started; std::future wait_started_future = wait_started.get_future(); std::atomic wait_completed{false}; std::thread waiter([&queue, &wait_started, &wait_completed]() { wait_started.set_value(); queue.WaitForRoomAtLeast(2); // Should block until 2 slots are free wait_completed = true; }); wait_started_future.wait(); 
std::this_thread::sleep_for(std::chrono::milliseconds(10)); EXPECT_FALSE(wait_completed); // Free up space queue.Get(); queue.Get(); waiter.join(); EXPECT_TRUE(wait_completed); } TEST_F(QueueTest, WaitForRoomAtMost) { Queue queue(5); auto producer = queue.CreateProducer(); // Fill queue partially producer.Put(1); producer.Put(2); producer.Put(3); // Should wait until room <= 2 (currently room = 2) queue.WaitForRoomAtMost(2); EXPECT_EQ(queue.Size(), 3); // Test blocking behavior std::promise wait_started; std::future wait_started_future = wait_started.get_future(); std::atomic wait_completed{false}; std::thread waiter([&queue, &wait_started, &wait_completed]() { wait_started.set_value(); queue.WaitForRoomAtMost(1); // Should block until room <= 1 wait_completed = true; }); wait_started_future.wait(); std::this_thread::sleep_for(std::chrono::milliseconds(10)); EXPECT_FALSE(wait_completed); // Add one more item to make room = 1 producer.Put(4); waiter.join(); EXPECT_TRUE(wait_completed); EXPECT_EQ(queue.Size(), 4); } TEST_F(QueueTest, WaitForSizeAtLeast) { Queue queue(5); auto producer = queue.CreateProducer(); // Test blocking behavior on empty queue std::promise wait_started; std::future wait_started_future = wait_started.get_future(); std::atomic wait_completed{false}; std::thread waiter([&queue, &wait_started, &wait_completed]() { wait_started.set_value(); queue.WaitForSizeAtLeast(3); // Should block until size >= 3 wait_completed = true; }); wait_started_future.wait(); std::this_thread::sleep_for(std::chrono::milliseconds(10)); EXPECT_FALSE(wait_completed); // Add items producer.Put(1); producer.Put(2); std::this_thread::sleep_for(std::chrono::milliseconds(10)); EXPECT_FALSE(wait_completed); producer.Put(3); // Now size = 3 waiter.join(); EXPECT_TRUE(wait_completed); EXPECT_EQ(queue.Size(), 3); } TEST_F(QueueTest, WaitForSizeAtMost) { Queue queue(5); auto producer = queue.CreateProducer(); // Initially empty, size <= 3 queue.WaitForSizeAtMost(3); 
EXPECT_EQ(queue.Size(), 0); // Fill queue producer.Put(1); producer.Put(2); producer.Put(3); producer.Put(4); producer.Put(5); // Test blocking behavior std::promise wait_started; std::future wait_started_future = wait_started.get_future(); std::atomic wait_completed{false}; std::thread waiter([&queue, &wait_started, &wait_completed]() { wait_started.set_value(); queue.WaitForSizeAtMost(2); // Should block until size <= 2 wait_completed = true; }); wait_started_future.wait(); std::this_thread::sleep_for(std::chrono::milliseconds(10)); EXPECT_FALSE(wait_completed); // Remove items queue.Get(); queue.Get(); std::this_thread::sleep_for(std::chrono::milliseconds(10)); EXPECT_FALSE(wait_completed); queue.Get(); // Now size = 2 waiter.join(); EXPECT_TRUE(wait_completed); EXPECT_EQ(queue.Size(), 2); } TEST_F(QueueTest, WaitFunctionsEdgeCases) { Queue queue(3); auto producer = queue.CreateProducer(); // Wait for room = 0 should work when queue is full producer.Put(1); producer.Put(2); producer.Put(3); queue.WaitForRoomAtMost(0); EXPECT_EQ(queue.Size(), 3); // Wait for size = 0 should work when queue is empty queue.Get(); queue.Get(); queue.Get(); queue.WaitForSizeAtMost(0); EXPECT_EQ(queue.Size(), 0); // Wait for room >= capacity should always succeed queue.WaitForRoomAtLeast(3); EXPECT_EQ(queue.Size(), 0); // Wait for size >= 0 should always succeed queue.WaitForSizeAtLeast(0); EXPECT_EQ(queue.Size(), 0); } // Tests for gradual large range operations TEST_F(QueueTest, BatchPutAtCapacityWorks) { Queue queue(3); auto producer = queue.CreateProducer(); // Putting exactly capacity worth of items should work std::vector items = {1, 2, 3}; producer.Put(absl::Span(items)); EXPECT_EQ(queue.Size(), 3); } TEST_F(QueueTest, BatchGetAtCapacityWorks) { Queue queue(3); auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); producer.Put(3); // Getting exactly capacity worth of items should work auto result = queue.Get(3); EXPECT_EQ(result.size(), 3); 
EXPECT_EQ(result[0], 1); EXPECT_EQ(result[1], 2); EXPECT_EQ(result[2], 3); } TEST_F(QueueTest, LargeRangePutGetGradual) { Queue queue(3); // Small capacity auto producer = queue.CreateProducer(); // Put more items than capacity - should work gradually std::vector large_items = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; std::thread put_thread([&producer, &large_items]() { producer.Put(absl::Span(large_items)); }); // Consume items as they become available std::vector consumed; for (int i = 0; i < 10; ++i) { consumed.push_back(queue.Get()); } put_thread.join(); // Verify all items were transferred correctly EXPECT_EQ(consumed.size(), 10); for (int i = 0; i < 10; ++i) { EXPECT_EQ(consumed[i], i + 1); } EXPECT_EQ(queue.Size(), 0); } TEST_F(QueueTest, LargeRangePutMove) { Queue> queue(2); // Very small capacity auto producer = queue.CreateProducer(); // Create large batch of move-only items std::vector> large_items; for (int i = 1; i <= 5; ++i) { large_items.push_back(std::make_unique(i)); } std::thread put_thread([&producer, &large_items]() { producer.Put(absl::Span>(large_items)); }); // Consume items as they become available std::vector> consumed; for (int i = 0; i < 5; ++i) { consumed.push_back(queue.Get()); } put_thread.join(); // Verify all items were transferred correctly EXPECT_EQ(consumed.size(), 5); for (int i = 0; i < 5; ++i) { EXPECT_EQ(*consumed[i], i + 1); } EXPECT_EQ(queue.Size(), 0); } TEST_F(QueueTest, LargeRangeGetGradual) { Queue queue(3); // Small capacity auto producer = queue.CreateProducer(); // Start a thread that will gradually produce items std::thread producer_thread([&producer]() { for (int i = 1; i <= 10; ++i) { producer.Put(i); std::this_thread::sleep_for(std::chrono::milliseconds(1)); } }); // Get more items than capacity - should work gradually auto result = queue.Get(10); producer_thread.join(); // Verify all items were retrieved correctly EXPECT_EQ(result.size(), 10); for (int i = 0; i < 10; ++i) { EXPECT_EQ(result[i], i + 1); } 
EXPECT_EQ(queue.Size(), 0); } TEST_F(QueueTest, LargeRangePutGetConcurrent) { Queue queue(5); // Medium capacity constexpr int total_items = 100; constexpr int batch_size = 25; auto producer1 = queue.CreateProducer(); auto producer2 = queue.CreateProducer(); std::vector batch1, batch2; for (int i = 0; i < batch_size; ++i) { batch1.push_back(i); batch2.push_back(i + batch_size); } std::atomic items_consumed{0}; std::vector all_consumed; all_consumed.reserve(total_items); // Multiple producers std::thread producer_thread1([&producer1, &batch1]() { producer1.Put(absl::Span(batch1)); producer1.Put(absl::Span(batch1)); // Put twice }); std::thread producer_thread2([&producer2, &batch2]() { producer2.Put(absl::Span(batch2)); producer2.Put(absl::Span(batch2)); // Put twice }); // Consumer getting in large batches std::thread consumer_thread([&queue, &all_consumed, &items_consumed]() { while (items_consumed < total_items) { try { auto batch = queue.Get(std::min(15, total_items - items_consumed)); for (const auto& item : batch) { all_consumed.push_back(item); } items_consumed += batch.size(); } catch (const QueueClosedException&) { break; } } }); producer_thread1.join(); producer_thread2.join(); // Close the queue by destroying producers producer1.Close(); producer2.Close(); consumer_thread.join(); EXPECT_EQ(all_consumed.size(), total_items); EXPECT_EQ(queue.Size(), 0); } TEST_F(QueueTest, GradualOperationsWithQueueClosure) { Queue queue(2); // Very small capacity auto producer = queue.CreateProducer(); std::vector large_batch = {1, 2, 3, 4, 5}; std::atomic exception_caught{false}; std::thread producer_thread([&producer, &large_batch, &exception_caught]() { try { producer.Put(absl::Span(large_batch)); } catch (const QueueClosedException&) { exception_caught = true; } }); // Let producer start putting items std::this_thread::sleep_for(std::chrono::milliseconds(10)); // Consume a couple items queue.Get(); // Should get 1 queue.Get(); // Should get 2 // Close the queue while 
producer is still trying to put items queue.Close(); producer_thread.join(); EXPECT_TRUE(exception_caught); // Queue might have some items that were put before closure // but we can't predict exactly how many due to timing } // Tests for total put count functionality TEST_F(QueueTest, GetTotalPutCountBasic) { Queue queue(5); EXPECT_EQ(queue.GetTotalPutCount(), 0); { auto producer = queue.CreateProducer(); producer.Put(1); EXPECT_EQ(queue.GetTotalPutCount(), 1); producer.Put(2); producer.Put(3); EXPECT_EQ(queue.GetTotalPutCount(), 3); } // Count should persist after producer destruction EXPECT_EQ(queue.GetTotalPutCount(), 3); // Count should persist after getting items queue.Get(); queue.Get(); EXPECT_EQ(queue.GetTotalPutCount(), 3); } TEST_F(QueueTest, GetTotalPutCountBatch) { Queue queue(10); auto producer = queue.CreateProducer(); std::vector batch1 = {1, 2, 3}; std::vector batch2 = {4, 5}; producer.Put(absl::Span(batch1)); EXPECT_EQ(queue.GetTotalPutCount(), 3); producer.Put(absl::Span(batch2)); EXPECT_EQ(queue.GetTotalPutCount(), 5); // Single put after batch producer.Put(6); EXPECT_EQ(queue.GetTotalPutCount(), 6); } TEST_F(QueueTest, GetTotalPutCountReset) { Queue queue(5); auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); producer.Put(3); EXPECT_EQ(queue.GetTotalPutCount(), 3); // Reset and verify return value EXPECT_EQ(queue.GetTotalPutCount(true), 3); EXPECT_EQ(queue.GetTotalPutCount(), 0); // Add more items producer.Put(4); EXPECT_EQ(queue.GetTotalPutCount(), 1); // Non-reset call should not affect counter EXPECT_EQ(queue.GetTotalPutCount(false), 1); EXPECT_EQ(queue.GetTotalPutCount(), 1); } TEST_F(QueueTest, GetTotalPutCountThreadSafe) { Queue queue(50); // Large capacity to avoid blocking constexpr int items_per_thread = 10; constexpr int num_threads = 2; std::vector threads; std::vector::Producer> producers; // Create producers for each thread for (int t = 0; t < num_threads; ++t) { producers.push_back(queue.CreateProducer()); } 
for (int t = 0; t < num_threads; ++t) { threads.emplace_back([&producers, t]() { for (int i = 0; i < items_per_thread; ++i) { producers[t].Put(t * items_per_thread + i); } }); } for (auto& thread : threads) { thread.join(); } EXPECT_EQ(queue.GetTotalPutCount(), items_per_thread * num_threads); } TEST_F(QueueTest, GetTotalPutCountBatchThreadSafe) { Queue queue(100); // Large capacity to avoid blocking std::vector threads; std::vector::Producer> producers; // Create producers for each thread for (int t = 0; t < 2; ++t) { producers.push_back(queue.CreateProducer()); } for (int t = 0; t < 2; ++t) { threads.emplace_back([&producers, t]() { std::vector batch; int batch_size = (t + 1) * 5; // 5, 10 items for (int i = 0; i < batch_size; ++i) { batch.push_back(t * 100 + i); } producers[t].Put(absl::Span(batch)); }); } for (auto& thread : threads) { thread.join(); } EXPECT_EQ(queue.GetTotalPutCount(), 15); // 5 + 10 } TEST_F(QueueTest, GetTotalPutCountWithMoveSemantics) { Queue> queue(5); auto producer = queue.CreateProducer(); // Single move put auto ptr1 = std::make_unique(42); producer.Put(std::move(ptr1)); EXPECT_EQ(queue.GetTotalPutCount(), 1); // Batch move put std::vector> batch; for (int i = 0; i < 3; ++i) { batch.push_back(std::make_unique(i)); } producer.Put(absl::Span>(batch)); EXPECT_EQ(queue.GetTotalPutCount(), 4); } TEST_F(QueueTest, GetTotalPutCountEmptyBatch) { Queue queue(5); auto producer = queue.CreateProducer(); std::vector empty_batch; producer.Put(absl::Span(empty_batch)); EXPECT_EQ(queue.GetTotalPutCount(), 0); producer.Put(1); EXPECT_EQ(queue.GetTotalPutCount(), 1); producer.Put(absl::Span(empty_batch)); EXPECT_EQ(queue.GetTotalPutCount(), 1); } // Tests for DROP_NEW overflow behavior TEST_F(QueueTest, DropNewBasicBehavior) { Queue queue(3, Queue::OverflowBehavior::DROP_NEW); auto producer = queue.CreateProducer(); // Fill queue to capacity producer.Put(1); producer.Put(2); producer.Put(3); EXPECT_EQ(queue.Size(), 3); 
EXPECT_EQ(queue.GetTotalPutCount(), 3); EXPECT_EQ(queue.GetTotalDropCount(), 0); // Additional puts should be dropped producer.Put(4); producer.Put(5); EXPECT_EQ(queue.Size(), 3); EXPECT_EQ(queue.GetTotalPutCount(), 5); EXPECT_EQ(queue.GetTotalDropCount(), 2); // Verify original items are still there EXPECT_EQ(queue.Get(), 1); EXPECT_EQ(queue.Get(), 2); EXPECT_EQ(queue.Get(), 3); } TEST_F(QueueTest, DropNewBatchBehavior) { Queue queue(3, Queue::OverflowBehavior::DROP_NEW); auto producer = queue.CreateProducer(); // Fill queue partially producer.Put(1); EXPECT_EQ(queue.Size(), 1); // Try to put more than capacity allows std::vector large_batch = {2, 3, 4, 5, 6}; producer.Put(absl::Span(large_batch)); // Only first 2 should fit EXPECT_EQ(queue.Size(), 3); EXPECT_EQ(queue.GetTotalPutCount(), 6); // 1 + 5 attempted EXPECT_EQ(queue.GetTotalDropCount(), 3); // 4, 5, 6 were dropped // Verify what's in the queue EXPECT_EQ(queue.Get(), 1); EXPECT_EQ(queue.Get(), 2); EXPECT_EQ(queue.Get(), 3); } TEST_F(QueueTest, DropNewThreadSafety) { Queue queue(5, Queue::OverflowBehavior::DROP_NEW); constexpr int num_threads = 3; constexpr int items_per_thread = 10; std::vector threads; std::vector::Producer> producers; for (int t = 0; t < num_threads; ++t) { producers.push_back(queue.CreateProducer()); } for (int t = 0; t < num_threads; ++t) { threads.emplace_back([&producers, t]() { for (int i = 0; i < items_per_thread; ++i) { producers[t].Put(t * items_per_thread + i); } }); } for (auto& thread : threads) { thread.join(); } // Queue should have at most capacity items EXPECT_LE(queue.Size(), 5); // All puts should be counted EXPECT_EQ(queue.GetTotalPutCount(), num_threads * items_per_thread); // Some items should have been dropped EXPECT_GT(queue.GetTotalDropCount(), 0); // Put count = successful puts + drops EXPECT_EQ(queue.GetTotalPutCount(), queue.Size() + queue.GetTotalDropCount()); } // Tests for KEEP_NEWEST overflow behavior TEST_F(QueueTest, KeepNewestBasicBehavior) { Queue 
queue(3, Queue::OverflowBehavior::KEEP_NEWEST); auto producer = queue.CreateProducer(); // Fill queue to capacity producer.Put(1); producer.Put(2); producer.Put(3); EXPECT_EQ(queue.Size(), 3); EXPECT_EQ(queue.GetTotalPutCount(), 3); EXPECT_EQ(queue.GetTotalDropCount(), 0); // Additional puts should replace oldest items producer.Put(4); EXPECT_EQ(queue.Size(), 3); EXPECT_EQ(queue.GetTotalPutCount(), 4); EXPECT_EQ(queue.GetTotalDropCount(), 1); producer.Put(5); EXPECT_EQ(queue.Size(), 3); EXPECT_EQ(queue.GetTotalPutCount(), 5); EXPECT_EQ(queue.GetTotalDropCount(), 2); // Verify newest items are kept (3, 4, 5) EXPECT_EQ(queue.Get(), 3); EXPECT_EQ(queue.Get(), 4); EXPECT_EQ(queue.Get(), 5); } TEST_F(QueueTest, KeepNewestBatchBehavior) { Queue queue(3, Queue::OverflowBehavior::KEEP_NEWEST); auto producer = queue.CreateProducer(); // Fill queue producer.Put(1); producer.Put(2); producer.Put(3); // Put large batch that exceeds capacity std::vector large_batch = {4, 5, 6, 7, 8}; producer.Put(absl::Span(large_batch)); // Queue should still have capacity items EXPECT_EQ(queue.Size(), 3); EXPECT_EQ(queue.GetTotalPutCount(), 8); EXPECT_EQ(queue.GetTotalDropCount(), 5); // 1, 2, 3, 4, 5 dropped // Should have the newest 3 items (6, 7, 8) EXPECT_EQ(queue.Get(), 6); EXPECT_EQ(queue.Get(), 7); EXPECT_EQ(queue.Get(), 8); } TEST_F(QueueTest, KeepNewestLargeBatch) { Queue queue(2, Queue::OverflowBehavior::KEEP_NEWEST); auto producer = queue.CreateProducer(); // Put batch larger than capacity std::vector large_batch = {1, 2, 3, 4, 5}; producer.Put(absl::Span(large_batch)); // Should keep only the last 2 items EXPECT_EQ(queue.Size(), 2); EXPECT_EQ(queue.GetTotalPutCount(), 5); EXPECT_EQ(queue.GetTotalDropCount(), 3); EXPECT_EQ(queue.Get(), 4); EXPECT_EQ(queue.Get(), 5); } // Tests for counter functionality TEST_F(QueueTest, GetTotalGetCountBasic) { Queue queue(5); auto producer = queue.CreateProducer(); EXPECT_EQ(queue.GetTotalGetCount(), 0); producer.Put(1); producer.Put(2); 
producer.Put(3); EXPECT_EQ(queue.GetTotalGetCount(), 0); queue.Get(); EXPECT_EQ(queue.GetTotalGetCount(), 1); queue.Get(); queue.Get(); EXPECT_EQ(queue.GetTotalGetCount(), 3); } TEST_F(QueueTest, GetTotalGetCountBatch) { Queue queue(5); auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); producer.Put(3); producer.Put(4); producer.Put(5); queue.Get(3); EXPECT_EQ(queue.GetTotalGetCount(), 3); queue.Get(); EXPECT_EQ(queue.GetTotalGetCount(), 4); queue.Get(1); EXPECT_EQ(queue.GetTotalGetCount(), 5); } TEST_F(QueueTest, GetTotalGetCountReset) { Queue queue(3); auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); queue.Get(); queue.Get(); EXPECT_EQ(queue.GetTotalGetCount(), 2); EXPECT_EQ(queue.GetTotalGetCount(true), 2); EXPECT_EQ(queue.GetTotalGetCount(), 0); } TEST_F(QueueTest, GetTotalDropCountBasic) { Queue queue(2, Queue::OverflowBehavior::DROP_NEW); auto producer = queue.CreateProducer(); EXPECT_EQ(queue.GetTotalDropCount(), 0); producer.Put(1); producer.Put(2); EXPECT_EQ(queue.GetTotalDropCount(), 0); producer.Put(3); // Should be dropped EXPECT_EQ(queue.GetTotalDropCount(), 1); producer.Put(4); // Should be dropped EXPECT_EQ(queue.GetTotalDropCount(), 2); } TEST_F(QueueTest, GetTotalDropCountKeepNewest) { Queue queue(2, Queue::OverflowBehavior::KEEP_NEWEST); auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); EXPECT_EQ(queue.GetTotalDropCount(), 0); producer.Put(3); // Should drop 1 EXPECT_EQ(queue.GetTotalDropCount(), 1); producer.Put(4); // Should drop 2 EXPECT_EQ(queue.GetTotalDropCount(), 2); } TEST_F(QueueTest, GetTotalDropCountReset) { Queue queue(1, Queue::OverflowBehavior::DROP_NEW); auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); // Dropped producer.Put(3); // Dropped EXPECT_EQ(queue.GetTotalDropCount(), 2); EXPECT_EQ(queue.GetTotalDropCount(true), 2); EXPECT_EQ(queue.GetTotalDropCount(), 0); } // MaybeGet() tests TEST_F(QueueTest, MaybeGetOnEmptyQueue) { Queue 
queue(5); auto result = queue.MaybeGet(); EXPECT_FALSE(result.has_value()); } TEST_F(QueueTest, MaybeGetOnNonEmptyQueue) { Queue queue(5); { auto producer = queue.CreateProducer(); producer.Put(42); } auto result = queue.MaybeGet(); ASSERT_TRUE(result.has_value()); EXPECT_EQ(*result, 42); EXPECT_EQ(queue.Size(), 0); } TEST_F(QueueTest, MaybeGetMultipleValues) { Queue queue(5); { auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); producer.Put(3); } auto result1 = queue.MaybeGet(); ASSERT_TRUE(result1.has_value()); EXPECT_EQ(*result1, 1); auto result2 = queue.MaybeGet(); ASSERT_TRUE(result2.has_value()); EXPECT_EQ(*result2, 2); auto result3 = queue.MaybeGet(); ASSERT_TRUE(result3.has_value()); EXPECT_EQ(*result3, 3); auto result4 = queue.MaybeGet(); EXPECT_FALSE(result4.has_value()); } TEST_F(QueueTest, MaybeGetWithMoveOnlyType) { Queue> queue(5); { auto producer = queue.CreateProducer(); producer.Put(std::make_unique(42)); } auto result = queue.MaybeGet(); ASSERT_TRUE(result.has_value()); ASSERT_NE(*result, nullptr); EXPECT_EQ(**result, 42); } TEST_F(QueueTest, MaybeGetUpdatesGetCount) { Queue queue(5); { auto producer = queue.CreateProducer(); producer.Put(1); producer.Put(2); } EXPECT_EQ(queue.GetTotalGetCount(), 0); queue.MaybeGet(); EXPECT_EQ(queue.GetTotalGetCount(), 1); queue.MaybeGet(); EXPECT_EQ(queue.GetTotalGetCount(), 2); queue.MaybeGet(); // Empty queue. EXPECT_EQ(queue.GetTotalGetCount(), 2); // Count not incremented. 
} // Tests for stop_token cancellation TEST_F(QueueTest, StopTokenCancelsPut) { Queue queue(2); auto producer = queue.CreateProducer(); // Fill the queue producer.Put(1); producer.Put(2); std::stop_source stop_source; stop_source.request_stop(); // Should immediately throw without blocking since token is already stopped EXPECT_THROW(producer.Put(3, stop_source.get_token()), QueueRequestCancelled); EXPECT_EQ(queue.Size(), 2); } TEST_F(QueueTest, StopTokenCancelsGet) { Queue queue(5); auto producer = queue.CreateProducer(); std::stop_source stop_source; stop_source.request_stop(); // Should immediately throw without blocking since token is already stopped EXPECT_THROW(queue.Get(stop_source.get_token()), QueueRequestCancelled); EXPECT_EQ(queue.Size(), 0); } TEST_F(QueueTest, StopTokenCancelsBatchPut) { Queue queue(2); auto producer = queue.CreateProducer(); // Fill queue completely producer.Put(1); producer.Put(2); std::stop_source stop_source; stop_source.request_stop(); // Should immediately throw without blocking since token is already stopped std::vector items = {3, 4, 5}; EXPECT_THROW( producer.Put(absl::Span(items), stop_source.get_token()), QueueRequestCancelled); } TEST_F(QueueTest, StopTokenCancelsBatchGet) { Queue queue(5); auto producer = queue.CreateProducer(); std::stop_source stop_source; stop_source.request_stop(); // Should immediately throw without blocking since token is already stopped EXPECT_THROW(queue.Get(10, stop_source.get_token()), QueueRequestCancelled); } TEST_F(QueueTest, StopTokenCancelsWaitForRoomAtLeast) { Queue queue(3); auto producer = queue.CreateProducer(); // Fill queue producer.Put(1); producer.Put(2); producer.Put(3); std::stop_source stop_source; stop_source.request_stop(); // Should immediately throw without blocking since token is already stopped EXPECT_THROW(queue.WaitForRoomAtLeast(2, stop_source.get_token()), QueueRequestCancelled); } TEST_F(QueueTest, StopTokenCancelsWaitForSizeAtLeast) { Queue queue(5); auto producer = 
queue.CreateProducer(); std::stop_source stop_source; stop_source.request_stop(); // Should immediately throw without blocking since token is already stopped EXPECT_THROW(queue.WaitForSizeAtLeast(3, stop_source.get_token()), QueueRequestCancelled); } } // namespace lczero ================================================ FILE: csrc/utils/stream_shuffler.cc ================================================ #include "utils/stream_shuffler.h" namespace lczero { namespace training { void StreamShuffler::SetUpperBound(size_t upper_bound) { assert(upper_bound >= upper_bound_); stream_size_ += upper_bound - upper_bound_; while (upper_bound_ < upper_bound) { if (buckets_.empty() || buckets_.back().GetRemainingCapacity() == 0) { buckets_.emplace_back(upper_bound_, bucket_size_); } upper_bound_ = std::min( upper_bound, upper_bound_ + buckets_.back().GetRemainingCapacity()); buckets_.back().Extend(upper_bound_); } } void StreamShuffler::SetLowerBound(size_t lower_bound) { assert(lower_bound >= lower_bound_); lower_bound_ = lower_bound; if (lower_bound >= upper_bound_) { upper_bound_ = lower_bound; stream_size_ = 0; buckets_.clear(); return; } while (!buckets_.empty() && buckets_.front().upper_bound() <= lower_bound_) { stream_size_ -= buckets_.front().size(); buckets_.pop_front(); } if (!buckets_.empty()) { auto old_size = buckets_.front().size(); buckets_.front().DeclareLowerBound(lower_bound_); stream_size_ -= old_size - buckets_.front().size(); } } std::optional StreamShuffler::GetNextItem() { auto try_fetch = [&]() -> size_t { size_t item_idx = absl::Uniform(gen_, size_t{0}, stream_size_); --stream_size_; for (auto& bucket : buckets_) { if (item_idx < bucket.size()) return bucket.Fetch(item_idx); item_idx -= bucket.size(); } throw std::logic_error("StreamShuffler: item index out of bounds"); }; while (stream_size_ > 0) { if (auto item = try_fetch(); item >= lower_bound_) return item; } return std::nullopt; } void StreamShuffler::Reset(size_t lower_bound, size_t upper_bound) 
{ // Reset all internal state buckets_.clear(); stream_size_ = 0; upper_bound_ = lower_bound; lower_bound_ = lower_bound; // Establish the bounds, which will build the buckets with fresh data if (upper_bound > lower_bound) { SetUpperBound(upper_bound); } } StreamShuffler::Bucket::Bucket(size_t lower_bound, size_t capacity) : upper_bound_(lower_bound), items_(capacity) {} size_t StreamShuffler::Bucket::GetRemainingCapacity() const { return items_.size() - items_count_; } void StreamShuffler::Bucket::Extend(size_t new_upper_bound) { assert(new_upper_bound >= upper_bound_); const size_t increase = new_upper_bound - upper_bound_; assert(increase <= GetRemainingCapacity()); std::iota(items_.begin() + items_count_, items_.begin() + items_count_ + increase, upper_bound_); items_count_ += increase; upper_bound_ = new_upper_bound; } size_t StreamShuffler::Bucket::Fetch(size_t item_idx) { assert(item_idx < items_count_); size_t item = items_[item_idx]; std::swap(items_[item_idx], items_[--items_count_]); return item; } void DeclareLowerBound(size_t new_lower_bound); void StreamShuffler::Bucket::DeclareLowerBound(size_t new_lower_bound) { if (upper_bound_ - new_lower_bound < 2 * items_count_) return; // If the bucket has much more items that the allowed range, there are many // items out of the range. It makes sense to sort and remove them. std::sort(items_.begin(), items_.begin() + items_count_, std::greater()); // Find the first item that is under the new lower bound. auto it = std::upper_bound(items_.begin(), items_.begin() + items_count_, new_lower_bound, std::greater()); items_count_ = it - items_.begin(); } } // namespace training } // namespace lczero ================================================ FILE: csrc/utils/stream_shuffler.h ================================================ #pragma once #include #include #include #include #include #include namespace lczero { namespace training { // Returns a number between [lower_bound, upper_bound) in shuffled order. 
// Both bounds can be changed at any time, and the stream will adapt // accordingly. Not thread-safe. class StreamShuffler { public: // Sets the upper bound (exclusive). Can only be increased. void SetUpperBound(size_t upper_bound); // Sets the lower bound (inclusive). Can only be increased. void SetLowerBound(size_t lower_bound); // Sets the bucket size for internal storage optimization. void SetBucketSize(size_t bucket_size) { bucket_size_ = bucket_size; } // Returns the next item in shuffled order, or nullopt if exhausted. std::optional GetNextItem(); // Resets the shuffler to restart iteration with specified bounds. void Reset(size_t lower_bound, size_t upper_bound); private: class Bucket { public: Bucket(size_t lower_bound, size_t capacity); size_t GetRemainingCapacity() const; void Extend(size_t new_upper_bound); size_t Fetch(size_t item_idx); void DeclareLowerBound(size_t new_lower_bound); size_t upper_bound() const { return upper_bound_; } size_t size() const { return items_count_; } private: size_t upper_bound_ = 0; size_t items_count_ = 0; absl::FixedArray items_; }; absl::BitGen gen_; std::deque buckets_; size_t stream_size_ = 0; size_t upper_bound_ = 0; size_t lower_bound_ = 0; size_t bucket_size_ = 524288; }; } // namespace training } // namespace lczero ================================================ FILE: csrc/utils/stream_shuffler_test.cc ================================================ #include "utils/stream_shuffler.h" #include #include #include #include #include namespace lczero { namespace training { class StreamShufflerTest : public ::testing::Test { protected: void SetUp() override { shuffler_.SetBucketSize(4); } StreamShuffler shuffler_; }; TEST_F(StreamShufflerTest, EmptyRangeReturnsNullopt) { shuffler_.SetUpperBound(10); shuffler_.SetLowerBound(10); EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt); } TEST_F(StreamShufflerTest, SingleItemRange) { shuffler_.SetUpperBound(1); shuffler_.SetLowerBound(0); auto item = shuffler_.GetNextItem(); 
// Sliding-window behavior tests: bounds move in bucket multiples and in
// non-multiples, separately and together, and the shuffler must still emit
// each in-range number exactly once.

TEST_F(StreamShufflerTest, BasicRangeGeneration) {
  shuffler_.SetUpperBound(5);
  shuffler_.SetLowerBound(0);
  std::set<size_t> seen;
  for (int i = 0; i < 5; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    EXPECT_GE(*value, 0);
    EXPECT_LT(*value, 5);
    EXPECT_TRUE(seen.insert(*value).second);  // No duplicates allowed.
  }
  EXPECT_EQ(seen.size(), 5);
  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);
}

TEST_F(StreamShufflerTest, HeadAdvancesByBucketMultiples) {
  shuffler_.SetUpperBound(4);
  shuffler_.SetLowerBound(0);
  std::set<size_t> seen;
  for (int i = 0; i < 4; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    seen.insert(*value);
  }
  EXPECT_EQ(seen.size(), 4);
  // Grow by exactly one bucket (bucket size is 4 in this fixture).
  shuffler_.SetUpperBound(8);
  for (int i = 0; i < 4; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    EXPECT_GE(*value, 0);
    EXPECT_LT(*value, 8);
    EXPECT_TRUE(seen.insert(*value).second);
  }
  EXPECT_EQ(seen.size(), 8);
  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);
}

TEST_F(StreamShufflerTest, HeadAdvancesByNonMultiples) {
  shuffler_.SetUpperBound(3);
  shuffler_.SetLowerBound(0);
  std::set<size_t> seen;
  for (int i = 0; i < 3; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    seen.insert(*value);
  }
  // Grow by 4, which does not align with the bucket size of 4 (3 -> 7).
  shuffler_.SetUpperBound(7);
  for (int i = 0; i < 4; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    EXPECT_GE(*value, 0);
    EXPECT_LT(*value, 7);
    EXPECT_TRUE(seen.insert(*value).second);
  }
  EXPECT_EQ(seen.size(), 7);
  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);
}

TEST_F(StreamShufflerTest, TailAdvancesByBucketMultiples) {
  shuffler_.SetUpperBound(12);
  shuffler_.SetLowerBound(0);
  std::set<size_t> seen;
  for (int i = 0; i < 4; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    EXPECT_GE(*value, 0);
    EXPECT_LT(*value, 12);
    EXPECT_TRUE(seen.insert(*value).second);
  }
  shuffler_.SetLowerBound(4);
  std::optional<size_t> value;
  while ((value = shuffler_.GetNextItem()).has_value()) {
    EXPECT_GE(*value, 4);
    EXPECT_LT(*value, 12);
    EXPECT_TRUE(seen.insert(*value).second);
  }
  // Every item of [4, 12) must have been produced at some point.
  for (size_t i = 4; i < 12; ++i) {
    EXPECT_TRUE(seen.count(i) > 0) << "Item " << i << " was never fetched";
  }
}

TEST_F(StreamShufflerTest, TailAdvancesByNonMultiples) {
  shuffler_.SetUpperBound(10);
  shuffler_.SetLowerBound(0);
  std::set<size_t> seen;
  for (int i = 0; i < 3; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    EXPECT_GE(*value, 0);
    EXPECT_LT(*value, 10);
    EXPECT_TRUE(seen.insert(*value).second);
  }
  // Advance the tail to 3, which is not a bucket boundary.
  shuffler_.SetLowerBound(3);
  std::optional<size_t> value;
  while ((value = shuffler_.GetNextItem()).has_value()) {
    EXPECT_GE(*value, 3);
    EXPECT_LT(*value, 10);
    EXPECT_TRUE(seen.insert(*value).second);
  }
  // Every item of [3, 10) must have been produced at some point.
  for (size_t i = 3; i < 10; ++i) {
    EXPECT_TRUE(seen.count(i) > 0) << "Item " << i << " was never fetched";
  }
}

TEST_F(StreamShufflerTest, BothBoundsSlideSimultaneously) {
  shuffler_.SetUpperBound(10);
  shuffler_.SetLowerBound(0);
  std::set<size_t> seen;
  for (int i = 0; i < 5; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    EXPECT_GE(*value, 0);
    EXPECT_LT(*value, 10);
    EXPECT_TRUE(seen.insert(*value).second);
  }
  // Slide the whole window forward by 5.
  shuffler_.SetUpperBound(15);
  shuffler_.SetLowerBound(5);
  std::optional<size_t> value;
  while ((value = shuffler_.GetNextItem()).has_value()) {
    EXPECT_GE(*value, 5);
    EXPECT_LT(*value, 15);
    EXPECT_TRUE(seen.insert(*value).second);
  }
  // Every item of [5, 15) must have been produced at some point.
  for (size_t i = 5; i < 15; ++i) {
    EXPECT_TRUE(seen.count(i) > 0) << "Item " << i << " was never fetched";
  }
}

TEST_F(StreamShufflerTest, ComplexSlidingWindow) {
  std::set<size_t> seen;
  shuffler_.SetUpperBound(6);
  shuffler_.SetLowerBound(0);
  for (int i = 0; i < 3; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    seen.insert(*value);
  }
  shuffler_.SetUpperBound(11);
  for (int i = 0; i < 2; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
    seen.insert(*value);
  }
  shuffler_.SetLowerBound(2);
  shuffler_.SetUpperBound(14);
  std::set<size_t> final_round;
  std::optional<size_t> value;
  while ((value = shuffler_.GetNextItem()).has_value()) {
    EXPECT_GE(*value, 2);
    EXPECT_LT(*value, 14);
    EXPECT_TRUE(final_round.insert(*value).second);
  }
  for (const auto& v : final_round) {
    EXPECT_GE(v, 2);
    EXPECT_LT(v, 14);
  }
}

TEST_F(StreamShufflerTest, UniquenessAcrossMultipleBuckets) {
  // 20 items / bucket size 4 => five buckets; items must stay unique.
  shuffler_.SetUpperBound(20);
  shuffler_.SetLowerBound(0);
  std::set<size_t> seen;
  std::optional<size_t> value;
  while ((value = shuffler_.GetNextItem()).has_value()) {
    EXPECT_GE(*value, 0);
    EXPECT_LT(*value, 20);
    EXPECT_TRUE(seen.insert(*value).second);
  }
  EXPECT_EQ(seen.size(), 20);
}

TEST_F(StreamShufflerTest, TailCatchesUpToHead) {
  shuffler_.SetUpperBound(8);
  shuffler_.SetLowerBound(0);
  for (int i = 0; i < 3; ++i) {
    auto value = shuffler_.GetNextItem();
    ASSERT_TRUE(value.has_value());
  }
  // Lower bound reaches the upper bound: the stream becomes empty.
  shuffler_.SetLowerBound(8);
  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);
}
((item = shuffler_.GetNextItem()).has_value() && count < 10) { second_round.insert(item.value()); count++; } // Should get all items again EXPECT_EQ(second_round.size(), 5); // Both rounds should contain the same set of items EXPECT_EQ(first_round, second_round); } } // namespace training } // namespace lczero ================================================ FILE: csrc/utils/tensor.h ================================================ #pragma once #include #include #include #include #include "absl/algorithm/container.h" #include "absl/container/fixed_array.h" #include "absl/types/span.h" namespace lczero { // Class that holds tensor which will be exposed through pybind11. class TensorBase { public: virtual ~TensorBase() = default; virtual void* data() = 0; virtual const void* data() const = 0; virtual const std::vector& shape() const = 0; virtual const std::vector& strides() const = 0; virtual size_t element_size() const = 0; virtual std::string py_format() const = 0; }; template class TypedTensor : public TensorBase { public: TypedTensor(std::initializer_list shape) : data_(CalculateTotalSize(shape)), shape_(shape.begin(), shape.end()) { // Calculate strides in row-major order (in bytes). 
strides_.resize(shape_.size()); size_t total_size = 1; for (int i = static_cast(shape_.size()) - 1; i >= 0; --i) { strides_[i] = total_size * sizeof(T); total_size *= shape_[i]; } } void* data() override { return data_.data(); } const void* data() const override { return data_.data(); } const std::vector& shape() const override { return shape_; } const std::vector& strides() const override { return strides_; } size_t element_size() const override { return sizeof(T); } std::string py_format() const override { if constexpr (std::is_same_v) { return "f"; } else if constexpr (std::is_same_v) { return "d"; } else if constexpr (std::is_same_v) { return "i"; } else if constexpr (std::is_same_v) { return "q"; } else { static_assert(std::is_same_v, "Unsupported tensor type"); } } T& operator[](absl::Span dims) { if (dims.size() != shape_.size()) { throw std::invalid_argument( "Number of dimensions must match tensor rank"); } return data_[CalculateOffset(dims)]; } const T& operator[](absl::Span dims) const { if (dims.size() != shape_.size()) { throw std::invalid_argument( "Number of dimensions must match tensor rank"); } return data_[CalculateOffset(dims)]; } absl::Span slice(absl::Span dims) { if (dims.size() > shape_.size()) { throw std::invalid_argument( "Number of dimensions cannot exceed tensor rank"); } return absl::Span(data_.data() + CalculateOffset(dims), CalculateSliceSize(dims.size())); } absl::Span slice(absl::Span dims) const { if (dims.size() > shape_.size()) { throw std::invalid_argument( "Number of dimensions cannot exceed tensor rank"); } return absl::Span(data_.data() + CalculateOffset(dims), CalculateSliceSize(dims.size())); } private: size_t CalculateOffset(absl::Span dims) const { return absl::c_inner_product( dims, strides_, size_t{0}, std::plus<>{}, [](ssize_t dim, ssize_t stride) { return dim * stride / sizeof(T); }); } size_t CalculateSliceSize(size_t dims_size) const { return absl::c_accumulate(absl::MakeConstSpan(shape_).subspan(dims_size), 
size_t{1}, std::multiplies{}); } static size_t CalculateTotalSize(std::initializer_list shape) { return absl::c_accumulate(shape, size_t{1}, std::multiplies{}); } absl::FixedArray data_; std::vector shape_; std::vector strides_; }; using TensorTuple = std::vector>; } // namespace lczero ================================================ FILE: csrc/utils/tensor_test.cc ================================================ // ABOUTME: Unit tests for tensor classes and their data access methods. // ABOUTME: Tests construction, element access, slicing, and error conditions. #include "utils/tensor.h" #include namespace lczero { namespace { TEST(TypedTensorTest, ConstructorAndBasicProperties) { TypedTensor tensor({2, 3, 4}); // Check shape EXPECT_EQ(tensor.shape().size(), 3); EXPECT_EQ(tensor.shape()[0], 2); EXPECT_EQ(tensor.shape()[1], 3); EXPECT_EQ(tensor.shape()[2], 4); // Check strides (in bytes, row-major order) EXPECT_EQ(tensor.strides().size(), 3); EXPECT_EQ(tensor.strides()[0], 12 * sizeof(float)); // 3 * 4 elements EXPECT_EQ(tensor.strides()[1], 4 * sizeof(float)); // 4 elements EXPECT_EQ(tensor.strides()[2], 1 * sizeof(float)); // 1 element // Check element size EXPECT_EQ(tensor.element_size(), sizeof(float)); // Check py_format EXPECT_EQ(tensor.py_format(), "f"); // Check data pointer is valid EXPECT_NE(tensor.data(), nullptr); } TEST(TypedTensorTest, PyFormatForDifferentTypes) { TypedTensor float_tensor({2}); EXPECT_EQ(float_tensor.py_format(), "f"); TypedTensor double_tensor({2}); EXPECT_EQ(double_tensor.py_format(), "d"); TypedTensor int32_tensor({2}); EXPECT_EQ(int32_tensor.py_format(), "i"); TypedTensor int64_tensor({2}); EXPECT_EQ(int64_tensor.py_format(), "q"); } TEST(TypedTensorTest, ElementAccess) { TypedTensor tensor({2, 3}); // Set some values tensor[{0, 0}] = 10; tensor[{0, 1}] = 11; tensor[{0, 2}] = 12; tensor[{1, 0}] = 20; tensor[{1, 1}] = 21; tensor[{1, 2}] = 22; // Check values EXPECT_EQ((tensor[{0, 0}]), 10); EXPECT_EQ((tensor[{0, 1}]), 11); 
TEST(TypedTensorTest, ConstElementAccess) {
  TypedTensor<int32_t> tensor({2, 2});
  tensor[{0, 0}] = 1;
  tensor[{0, 1}] = 2;
  tensor[{1, 0}] = 3;
  tensor[{1, 1}] = 4;
  // Reads through a const reference must see the same values.
  const auto& view = tensor;
  EXPECT_EQ((view[{0, 0}]), 1);
  EXPECT_EQ((view[{0, 1}]), 2);
  EXPECT_EQ((view[{1, 0}]), 3);
  EXPECT_EQ((view[{1, 1}]), 4);
}

TEST(TypedTensorTest, SliceAccess) {
  TypedTensor<int32_t> tensor({2, 3, 4});
  // Encode the coordinates into the value so slices are easy to verify.
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 4; ++k) {
        tensor[{i, j, k}] = i * 100 + j * 10 + k;
      }
    }
  }
  // Fixing the first dimension yields a 3x4 contiguous view.
  auto plane = tensor.slice({1});
  EXPECT_EQ(plane.size(), 12);
  EXPECT_EQ(plane[0], 100);  // tensor[{1, 0, 0}]
  EXPECT_EQ(plane[4], 110);  // tensor[{1, 1, 0}]
  // Fixing two dimensions yields a single row of four elements.
  auto row = tensor.slice({0, 1});
  EXPECT_EQ(row.size(), 4);
  EXPECT_EQ(row[0], 10);  // tensor[{0, 1, 0}]
  EXPECT_EQ(row[1], 11);  // tensor[{0, 1, 1}]
  EXPECT_EQ(row[2], 12);  // tensor[{0, 1, 2}]
  EXPECT_EQ(row[3], 13);  // tensor[{0, 1, 3}]
  // Fixing nothing yields a view of the whole tensor.
  auto whole = tensor.slice({});
  EXPECT_EQ(whole.size(), 24);
}

TEST(TypedTensorTest, ConstSliceAccess) {
  TypedTensor<int32_t> tensor({2, 2});
  tensor[{0, 0}] = 1;
  tensor[{0, 1}] = 2;
  tensor[{1, 0}] = 3;
  tensor[{1, 1}] = 4;
  const auto& view = tensor;
  auto row = view.slice({0});
  EXPECT_EQ(row.size(), 2);
  EXPECT_EQ(row[0], 1);
  EXPECT_EQ(row[1], 2);
}

TEST(TypedTensorTest, ElementAccessWrongDimensions) {
  TypedTensor<float> tensor({2, 3});
  // Indexing with the wrong rank must throw, both too few and too many.
  EXPECT_THROW((tensor[{0}]), std::invalid_argument);
  EXPECT_THROW((tensor[{0, 1, 2}]), std::invalid_argument);
}
std::invalid_argument); } TEST(TypedTensorTest, OneDimensionalTensor) { TypedTensor tensor({5}); EXPECT_EQ(tensor.shape().size(), 1); EXPECT_EQ(tensor.shape()[0], 5); EXPECT_EQ(tensor.strides()[0], sizeof(float)); tensor[{0}] = 1.0f; tensor[{4}] = 5.0f; EXPECT_EQ((tensor[{0}]), 1.0f); EXPECT_EQ((tensor[{4}]), 5.0f); auto slice = tensor.slice({}); EXPECT_EQ(slice.size(), 5); } } // namespace } // namespace lczero ================================================ FILE: csrc/utils/thread_pool.h ================================================ #pragma once #include #include #include #include #include #include #include #include #include "absl/functional/any_invocable.h" #include "absl/synchronization/mutex.h" namespace lczero { struct ThreadPoolOptions { // If true, starts new thread when task is enqueued and no threads are idle. bool grow_automatically = false; }; class ThreadPool { public: ThreadPool(size_t initial_threads = 0, const ThreadPoolOptions& options = ThreadPoolOptions(), std::stop_source stop_source = std::stop_source()); // Blocks until all tasks are completed and threads are joined. ~ThreadPool(); // Returns the stop_token for this thread pool. std::stop_token stop_token() const; // Enqueues a task for execution and returns a std::future. // If the provided function accepts a std::stop_token as its first argument, // one will be passed to it from the thread pool's stop source. template auto Enqueue(F&& f, Args&&... args) { // This lambda captures the common queuing logic. 
auto enqueue_common = [&](auto&& task_to_enqueue, auto&& future_to_return) { { absl::MutexLock lock(&mutex_); running_tasks_ += 1; while (options_.grow_automatically && running_tasks_ >= threads_.size()) { StartWorkerThread(); } pending_tasks_.emplace_back( [task = std::move(task_to_enqueue)]() mutable { task(); }); work_available_.Signal(); } return future_to_return; }; if constexpr (std::is_invocable_v) { using ReturnType = std::invoke_result_t; std::packaged_task task( std::bind(std::forward(f), stop_source_.get_token(), std::forward(args)...)); std::future future = task.get_future(); return enqueue_common(std::move(task), std::move(future)); } else { using ReturnType = std::invoke_result_t; std::packaged_task task( std::bind(std::forward(f), std::forward(args)...)); std::future future = task.get_future(); return enqueue_common(std::move(task), std::move(future)); } } // Waits for all tasks to complete. void WaitAll(); // Waits for at least one thread to become available, i.e. no tasks are // pending, and number of running tasks is less than the number of threads. void WaitForAvailableThread(); // Waits until the number of queued but not yet started tasks is below // the specified threshold. void WaitForPendingTasksBelow(size_t threshold); // Number of tasks that are not yet started. size_t num_pending_tasks() const; // Number of tasks that are currently running. size_t num_running_tasks() const; // Number of worker threads (busy or not). size_t num_threads() const; // Signal workers to terminate and join all threads. 
void Shutdown(); private: void WorkerLoop(); void WorkerEntryPoint(); void StartWorkerThread() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); bool AllTasksCompletedCond() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { return pending_tasks_.empty() && running_tasks_ == 0; } bool ThreadAvailableCond() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { return pending_tasks_.empty() && running_tasks_ < threads_.size(); } ThreadPool(const ThreadPool&) = delete; ThreadPool& operator=(const ThreadPool&) = delete; ThreadPool(ThreadPool&&) = delete; ThreadPool& operator=(ThreadPool&&) = delete; ThreadPoolOptions options_; mutable absl::Mutex mutex_; absl::CondVar work_available_; absl::CondVar work_done_; std::stop_source stop_source_; std::vector threads_ ABSL_GUARDED_BY(mutex_); std::deque> pending_tasks_ ABSL_GUARDED_BY(mutex_); size_t running_tasks_ ABSL_GUARDED_BY(mutex_) = 0; }; inline ThreadPool::ThreadPool(size_t initial_threads, const ThreadPoolOptions& options, std::stop_source stop_source) : options_(options), stop_source_(std::move(stop_source)) { absl::MutexLock lock(&mutex_); for (size_t i = 0; i < initial_threads; ++i) { StartWorkerThread(); } } inline ThreadPool::~ThreadPool() { Shutdown(); } inline std::stop_token ThreadPool::stop_token() const { return stop_source_.get_token(); } inline void ThreadPool::WorkerLoop() { while (true) { absl::AnyInvocable task; { absl::MutexLock lock(&mutex_); while (!stop_source_.stop_requested() && pending_tasks_.empty()) { work_available_.Wait(&mutex_); } if (stop_source_.stop_requested() && pending_tasks_.empty()) return; task = std::move(pending_tasks_.front()); pending_tasks_.pop_front(); } std::move(task)(); { absl::MutexLock lock(&mutex_); running_tasks_ -= 1; work_done_.SignalAll(); if (!pending_tasks_.empty()) work_available_.Signal(); } } } inline void ThreadPool::WorkerEntryPoint() { try { WorkerLoop(); } catch (const std::exception& exception) { std::cerr << "ThreadPool worker exited due to uncaught exception: " << 
exception.what() << std::endl; throw; } catch (...) { std::cerr << "ThreadPool worker exited due to unknown exception." << std::endl; throw; } } inline void ThreadPool::WaitAll() { absl::MutexLock lock(&mutex_); while (!AllTasksCompletedCond()) { work_done_.Wait(&mutex_); } } inline void ThreadPool::WaitForAvailableThread() { absl::MutexLock lock(&mutex_); while (!ThreadAvailableCond()) { work_done_.Wait(&mutex_); } } inline void ThreadPool::WaitForPendingTasksBelow(size_t threshold) { absl::MutexLock lock(&mutex_); while (pending_tasks_.size() >= threshold) { work_done_.Wait(&mutex_); } } inline void ThreadPool::StartWorkerThread() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { threads_.emplace_back(&ThreadPool::WorkerEntryPoint, this); } inline size_t ThreadPool::num_pending_tasks() const { absl::MutexLock lock(&mutex_); return pending_tasks_.size(); } inline size_t ThreadPool::num_running_tasks() const { absl::MutexLock lock(&mutex_); return std::max(running_tasks_, threads_.size()); } inline size_t ThreadPool::num_threads() const { absl::MutexLock lock(&mutex_); return threads_.size(); } inline void ThreadPool::Shutdown() { { absl::MutexLock lock(&mutex_); if (!stop_source_.stop_requested()) { stop_source_.request_stop(); } work_available_.SignalAll(); work_done_.SignalAll(); } threads_.clear(); } } // namespace lczero ================================================ FILE: csrc/utils/training_data_printer.cc ================================================ #include "utils/training_data_printer.h" #include #include #include #include "chess/board.h" #include "neural/decoder.h" #include "trainingdata/reader.h" namespace lczero { namespace training { void PrintFloatArray(const float* data, size_t size, absl::string_view name, int64_t per_line) { per_line = std::max(1, per_line); std::cout << " " << name << ":\n"; for (size_t i = 0; i < size; ++i) { if (i % per_line == 0) { std::cout << " [" << absl::StrFormat("%4zu", i) << "]: "; } std::cout << absl::StrFormat("% .6g", 
data[i]); if ((i + 1) % per_line == 0 || i + 1 == size) { std::cout << "\n"; } else { std::cout << ", "; } } } void PrintUint64Array(const uint64_t* data, size_t size, absl::string_view name, int64_t per_line) { per_line = std::max(1, per_line); std::cout << " " << name << ":\n"; for (size_t i = 0; i < size; ++i) { if (i % per_line == 0) { std::cout << " [" << absl::StrFormat("%3zu", i) << "]: "; } std::cout << absl::StrFormat("0x%016x", data[i]); if ((i + 1) % per_line == 0 || i + 1 == size) { std::cout << "\n"; } else { std::cout << ", "; } } } std::string DecodeInvarianceInfo(uint8_t invariance_info) { return absl::StrFormat( "flip=%d, mirror=%d, transpose=%d, best_move_proven=%d, " "max_length=%d, adjudicated=%d, rescorer_deleted=%d, side_to_move=%d", invariance_info & 0x1, (invariance_info >> 1) & 0x1, (invariance_info >> 2) & 0x1, (invariance_info >> 3) & 0x1, (invariance_info >> 4) & 0x1, (invariance_info >> 5) & 0x1, (invariance_info >> 6) & 0x1, (invariance_info >> 7) & 0x1); } std::string TrainingDataToFen(const FrameType& entry) { InputPlanes planes = PlanesFromTrainingData(entry); ChessBoard board; int rule50 = 0; int gameply = 0; PopulateBoard( static_cast(entry.input_format), planes, &board, &rule50, &gameply); std::string fen = BoardToFen(board); fen += " " + std::to_string(rule50); fen += " " + std::to_string((gameply / 2) + 1); return fen; } void PrintTrainingDataEntry(const FrameType& entry, absl::string_view header_text, int64_t float_per_line, int64_t plane_per_line) { std::cout << header_text << "\n"; std::cout << " FEN: " << TrainingDataToFen(entry) << "\n"; std::cout << " version: " << entry.version << "\n"; std::cout << " input_format: " << entry.input_format << "\n"; std::cout << " castling_us_ooo: " << static_cast(entry.castling_us_ooo) << "\n"; std::cout << " castling_us_oo: " << static_cast(entry.castling_us_oo) << "\n"; std::cout << " castling_them_ooo: " << static_cast(entry.castling_them_ooo) << "\n"; std::cout << " castling_them_oo: 
" << static_cast(entry.castling_them_oo) << "\n"; std::cout << " side_to_move_or_enpassant: " << static_cast(entry.side_to_move_or_enpassant) << "\n"; std::cout << " rule50_count: " << static_cast(entry.rule50_count) << "\n"; std::cout << " invariance_info: " << static_cast(entry.invariance_info) << " (" << DecodeInvarianceInfo(entry.invariance_info) << ")\n"; std::cout << " dummy: " << static_cast(entry.dummy) << "\n"; std::cout << " root_q: " << entry.root_q << "\n"; std::cout << " best_q: " << entry.best_q << "\n"; std::cout << " root_d: " << entry.root_d << "\n"; std::cout << " best_d: " << entry.best_d << "\n"; std::cout << " root_m: " << entry.root_m << "\n"; std::cout << " best_m: " << entry.best_m << "\n"; std::cout << " plies_left: " << entry.plies_left << "\n"; std::cout << " result_q: " << entry.result_q << "\n"; std::cout << " result_d: " << entry.result_d << "\n"; std::cout << " played_q: " << entry.played_q << "\n"; std::cout << " played_d: " << entry.played_d << "\n"; std::cout << " played_m: " << entry.played_m << "\n"; std::cout << " orig_q: " << entry.orig_q << "\n"; std::cout << " orig_d: " << entry.orig_d << "\n"; std::cout << " orig_m: " << entry.orig_m << "\n"; std::cout << " visits: " << entry.visits << "\n"; std::cout << " played_idx: " << entry.played_idx << "\n"; std::cout << " best_idx: " << entry.best_idx << "\n"; std::cout << " policy_kld: " << entry.policy_kld << "\n"; PrintFloatArray(entry.probabilities, std::size(entry.probabilities), "probabilities", float_per_line); PrintUint64Array(entry.planes, std::size(entry.planes), "planes", plane_per_line); std::cout << std::flush; } } // namespace training } // namespace lczero ================================================ FILE: csrc/utils/training_data_printer.h ================================================ #ifndef LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_ #define LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_ #include #include #include #include #include "loader/frame_type.h" 
#ifndef LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_
#define LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_

#include <cstddef>
#include <cstdint>
#include <string>

#include "absl/strings/string_view.h"
#include "loader/frame_type.h"
#include "trainingdata/trainingdata_v6.h"

namespace lczero {
namespace training {

// Prints a float array with configurable number of values per line.
void PrintFloatArray(const float* data, size_t size, absl::string_view name,
                     int64_t per_line);

// Prints a uint64 array with configurable number of values per line.
void PrintUint64Array(const uint64_t* data, size_t size,
                      absl::string_view name, int64_t per_line);

// Decodes the invariance_info byte into a human-readable string.
std::string DecodeInvarianceInfo(uint8_t invariance_info);

// Converts a training frame to FEN (Forsyth-Edwards Notation).
// FIX: the declaration previously took `const V6TrainingData&` while the
// definition in training_data_printer.cc takes `const FrameType&`; the
// parameter type is unified on FrameType to match the definition.
std::string TrainingDataToFen(const FrameType& entry);

// Prints a training frame with a custom header and formatting options.
void PrintTrainingDataEntry(const FrameType& entry,
                            absl::string_view header_text,
                            int64_t float_per_line, int64_t plane_per_line);

}  // namespace training
}  // namespace lczero

#endif  // LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_
```bash cd uv python install 3.12 uv venv uv sync uv pip install meson ruff git submodule update --init --recursive CXX=clang++ CC=clang uv run meson setup build/release\ --buildtype=release --native-file=native.ini uv run meson configure build/release \ -Dcpp_args='-Wno-error=deprecated-declarations' just build cd src/lczero_training ln -sfT ../../build/release/_lczero_training.cpython-*-x86_64-linux-gnu.so _lczero_training.so just build-proto ``` ## Training a model To train a model you need: * Training data * A configuration file * Create a checkpoint. * Run the pipeline. ### Training data Unlike the old training pipeline, the new one doesn't need .tar files to be unpacked. While it does support plain `.gz` chunk files, it's not efficient as it stores each individual file name in memory. So instead, use `.tar` files, the tool can index and seek inside them. The tool watches a directory (and its subdirectories) for new files. Terms used: * **Chunk**/Game: A single training game, individual `.gz` file. * **Chunk source**: A file (`.tar` or `.gz`) containing multiple chunks. * **Frame**/Record/Position: A single training position inside a chunk. * **Training tensor**: A single batch of inputs/outputs encoded in NN format for one training step. Incoming data comes as chunks, but for the training we need frames from different games. ### Note on RL vs SL training The tool supports both supervised learning (SL) and reinforcement learning (RL). 
Here is an overview of the configuration differences:

| RL Training | SL Training |
| --- | --- |
| `shuffling_chunk_pool` has a relatively small `chunk_pool_size`, which would be used as a sliding window. | `chunk_pool_size` should be larger than all data, so that all data is used for training. |
| `training.schedule.chunks_per_network` is non-zero — that's how many new chunks to wait for before starting a new epoch. | `chunks_per_network` is zero. Once an epoch is done, it starts a new one immediately. |
| RL currently uses "hanse sampling", where the entire chunk is loaded and rescored to use just one position from it. The reservoir `shuffling_frame_sampler` is not used in this case. This is currently slow (until we implement caching), so it limits the throughput. | For SL, it is better to use two-stage sampling: `shuffling_chunk_pool` in non-hanse mode, and then `shuffling_frame_sampler` to shuffle positions within chunks. |

### Creating a checkpoint

To create a fresh checkpoint, you'll need the `model` and `training` sections of the configuration file to be filled. Then run:

```bash
uv run lc0-init --config <config>.textproto --lczero_model <model>.pb.gz
```

The `--lczero_model` parameter is optional. If not given, the network is initialized with random weights.

### Training

Run:

```bash
CUDA_VISIBLE_DEVICES=0 uv run lc0-tui --config <config>.textproto --logfile train.log
```

Notes:

* There are log files both in `tui` and in the configuration file. They are slightly different; the TUI log is usually more useful.
* In the tool, you can press `q` to quit, and `Ctrl+p` to get a command palette. There, you have one useful command: "Start training immediately". * For multi-GPU training, ensure your batch size is divisible by the number of GPUs. * The overfit utility (`uv run overfit`) does not support multi-GPU. Set `CUDA_VISIBLE_DEVICES` to use only one GPU when running overfit. * Also note that TUI is 100% vibe coded, so you'll see lots of mocks in the UI. :-P ## Tools The repository consists of set of tools, mostly written in Python, but some are in C++. To run a Python tool, use `uv run `. C++ tools are binaries in `build/release`. Most tools need a configuration file as a parameter (see below). ### Python tools | Tool | Description | | ---------------------- | ------------------------------------------------------------------------------------------------------------ | | lc0-daemon | The main training daemon. It acts as a JSONL server, so it's not usable directly from command line yet. | | lc0-tui | A terminal user interface that runs the training daemon. Here is what you have to run. | | lc0-init | Initializes a new training run/checkpoint. | | lc0-migrate-checkpoint | Migrates JAX/Orbax checkpoint after model/training configuration changes. | | lc0-overfit | Runs an overfitting test: takes one batch from the data loader and repeatedly trains on it | | lc0-eval | Evals batches from the data loader on a given checkpoint, can dump inputs/outputs in various formats. | | lc0-leela2jax | | | lc0-describe | | | lc0-test-dataloader | | | lc0-tune-lr | Trains on exponentially increasing learning rate, and outputs losses into csv file. Useful for picking a LR. | | lc0-backfill-metrics | Loads older checkpoints computes metrics for them, and exports them to tensorboard. | | lc0-train | Trains a single epoch (doesn't save or export the model though). Used for benchmarking. | | lc0-weights | Manipulates weight files: arithmetic operations, grafting components, format conversion. 
See [weights_tool.md](weights_tool.md). | ### C++ tools | Tool | Description | | ---------------------------- | --------------------------------------------------------- | | rescore_chunk | Runs rescorer on a single chunk | | startpos_policy_distribution | | | result_distribution | | | filter_chunks | | | dump_chunk | Dumps the content of a chunk file for debugging purposes. | ## Configuration The configuration is a text protobuf file, with the following sections: | Section | Description | | ----------- | ---------------------------------------------- | | data_loader | Configuration for the data loader. | | model | Model architecture configuration. | | training | Configuration for the training process. | | metrics | Metrics to export into tensorboard. | | export | Configuration for exporting the trained model. | Also it has a `log_filename` field (where to write the log) and a `name` field (must be `little-teapot`). It's recommended to use existing configuration files as a starting point. ### Data loader configuration The data loader is a pipeline that consists of pluggable stages. Here are stages that are currently implemented: | Stage type | Description | Input | Output | | ------------------------- | ------------------------------------------------------------------------------------------------ | ------------ | --------------------- | | `file_path_provider` | Watches a directory for existing and new files. | None | Filenames | | `chunk_source_reader` | Reads and indexes chunk source files (`.tar` or `.gz`). | Filenames | ChunkSources | | `chunk_source_splitter` | Splits chunk sources into smaller chunk sources given the proportion (used for test/train split). 
| ChunkSources | multiple ChunkSources | | `shuffling_chunk_pool` | Accumulates chunk sources and outputs chunks in shuffled order | ChunkSources | Chunks | | `simple_chunk_extractor` | Unpacks chunks from chunk sources | ChunkSources | Chunks | | `chunk_rescorer` | Rescores chunks | Chunks | Chunks | | `chunk_unpacker` | Extracts positions from chunks | Chunks | Frames | | `shuffling_frame_sampler` | Outputs frames in shuffled order | Frames | Frames | | `tensor_generator` | Converts frames into training batches in numpy tensor format | Frames | Training Tensors | The pipeline ends with one or more outputs, which provide tuples of batched tensors for training. > [!NOTE] > The current format of the training batch is > > * `inputs`: float32 tensor of shape `[batch_size, 112, 8, 8]` > * `policy_target`: float32 tensor of shape `[batch_size, 1862]` > * `value_target`: float32 tensor of shape `[batch_size, 6, 3]`, where 6 rows > are sources of the value (`result`, `best`, `played`, `orig`, `root` and > `st`), and 3 columns are (`q` (w-l), `draw`, `movesleft`). Every stage must have an unique name (may or not be the same as the stage type), and arbitrary number of inputs (depending on the stage type; most have one input). Here is the structure of the data loader configuration: ```textproto stage { name: "file_provider" file_path_provider { # ... } } stage { name: "loader" input: "file_provider" chunk_source_reader { # ... output { name: "myoutput" } } } stage { name: "chunk_shuffler" input: "loader.myoutput" shuffling_chunk_pool { # ... } } # ... stage { name: "tensor_gen" input: "sampler" tensor_generator { batch_size: 256 # ... } } output: "tensor_gen" # unnamed output output: "test:test_tensor_gen" # named output ``` #### Stage output configuration Every stage provides one or more outputs. 
The configuration of the output is like this (all fields optional): ```textproto output { name: "myoutput" queue_capacity: 8 # default: 4 overflow_behavior: BLOCK } ``` * By default, outputs are not named, but you can name them. * Higher `queue_capacity` allows you to "pre-cache" data, so that when the rate of the producer stage is spiky, the pipeline is not blocked. On the other hand, the data in the queue may be "stale" (i.e. when you train a new network, the data in the queue is still for the old network). * `overflow_behavior` controls what happens when the output queue is full: * `BLOCK`: default and what's needed for most stages. The producer stage is blocked until there is space in the queue. * `DROP_NEW` and `KEEP_NEWEST` drops the data from the queue (either the incoming data, or the oldest data in the queue). These are useful e.g. for auxiliary output of a stage (e.g. validation), so that the auxiliary pipeline doesn't block the main pipeline. ### Stage configurations #### file_path_provider Watches a directory for existing and new files. First it sends all existing files, then sends special "Initial Scan Done" event, and then watches for new files. #### chunk_source_loader Takes the filenames from the input, and loads them as chunk sources. Skips files which are not chunk sources. * `frame_format`: `V6TrainingData` (default) or `V7TrainingData`. #### shuffling_chunk_pool Shuffling chunk pool is the central part of the data loader. In most cases, it is the only stage responsible for shuffling the data. In some cases, you may want to have secondary shuffling_frame_sampler after it (e.g. for SL training). Every chunk source has a "sort key" (currently, it's the file name without path). It's needed to determine the order of chunks to use for the sliding window. * `chunk_pool_size`: The size of the training window, in number of chunks. Even when there are not enough chunks yet, the stage will output chunks from what it has. 
It will not start producing data until the "initial scan done" event is received from the file_path_provider. * For RL training, typical values are 250k to 5M. * For SL training, it should be larger than all data, so that all data is used for training. ================================================ FILE: docs/architecture.md ================================================ # Architecture Overview The document outlines the architecture of the new Leela Chess Zero training system. The training process involves a Reinforcement Learning (RL) pipeline where new training data is continuously generated. The old script required fresh starts to train a new network on fresher data. Furthermore, the use of TensorFlow involved a very long model compilation time (approx. 1 hour), which dominated the actual training time for a single epoch. The new script will be a single, long-running Python application that automates the entire cycle: 1. Monitors for a sufficient amount of new training data. 2. Triggers and executes the training of a new network. 3. Exports the trained network for use. The core training loop will be implemented in JAX. The data [loading and preprocessing pipeline](loader.md) is a C++ code exposed to Python via pybind11. This C++ library is internally multi-threaded but exposes a simple API to Python (GetNextBatch(), GetStats()). We'd like to have a fancy TUI dashboard to monitor the loading and training process. ## TrainingDaemon The main (in terms of importance) class of the training code is `TrainingDaemon`, which: * Is a jsonl server operating through stdin/stdout, implementing the protocol described in [JSONL IPC Protocol](jsonl.md). * Receives a command to start training with the location of the config file. * Other commands are possible in the future. * Sends periodic progress notifications. * Owns the data loader. * Waits for the data loader to ingest enough new chunks. * Starts the training loop when enough data is available. 
* Finalizes and uploads the trained network. ## Configuration The configuration is a large nested dataclass structure, which covers: * [Data loader](../src/lczero_training/config/data_loader_config.py) to be passed to the C++ data loader. * Information for the training daemon, e.g. how many chunks to wait for before starting the training. * Model definition, for model builder. * Training parameters, such as batch size, number of epochs, etc. * Export parameters, such as the path to export the trained model to. From the user's perspective, the configuration is a YAML file, which is parsed into the dataclass structure. ## Data Loader The internals of the data loader are described in detail in [Data Loader](loader.md). From the Python perspective, it has the following interface: * Constructor takes a `DataLoaderConfig` dataclass. * `GetNextBatch()` returns a tuple of buffer-protocol-compliant tensors. * Later it will have a parameter that specifies whether we need a training, test or validation batch. * `GetStats()` returns a DataClass (exact structure TBD) with the current statistics of the data loader. ## TUI TUI is a separate frontend app, implemented as `TrainingTuiApp`, which runs `TrainingDaemon` as a subprocess (and communicates through jsonl via stdin/stdout). * TUI is a `Textual`-based app. * Located in `src/lczero_training/tui/`. * UI ideas are described in [TUI](tui.md). * TUI will have a log pane which shows stderr of `TrainingDaemon`. ================================================ FILE: docs/checkpoint_migration.md ================================================ # Checkpoint Migration When part of the model or training setup changes, JAX training state checkpoints may become incompatible with the new setup. For this, we provide a utility to help migrate checkpoints to the new setup. The underlying implementation lives in `src/lczero_training/training/migrate_checkpoint.py` and is exposed via the CLI command `migrate-checkpoint` registered in `pyproject.toml`. 
## Command line arguments * `--config`: Path to the RootConfig textproto config. `model`/`training` sections of this config will be used to initialize the new training state. `checkpoint` is the location of the old checkpoint to migrate. * `--new_checkpoint`: Path to save the new checkpoint to. If not set, the tool only checks whether the migration rules fully cover the differences between the old and new training states. * `--overwrite`: If set, allows overwriting existing checkpoint. Also in this case, if `--new_checkpoint` is not set, old checkpoint is used as the new checkpoint. * `--rules_file`: Path to a CheckpointMigrationConfig textproto file containing the migration rules. See below for the format of this file. If not set, no migration rules will be applied (used for debugging, or to check that the old and new states are identical). * `--serialized-model`: If set, use serialized state for a model. Checkpoint already loads serialized as we do not provide schema. This is needed to avoid `GetAttrKey`s. * `--checkpoint_step`: If set, use this step when loading from old checkpoint instead of the latest. * `--new_checkpoint_step`: If set, use this step when saving the new checkpoint instead of copying the old step. ## Creation of the state to compare The checkpoint is loaded into a raw pytree, i.e. template is not passed to it. (old checkpoint with new model would fail to load). Roughly, like this: ```python manager = ocp.CheckpointManager( filepath, options=ocp.CheckpointManagerOptions(create=False), ) state = manager.restore(step=None) ``` The model state is created using `TrainingState.new_from_config` (`src/lczero_training/training/state.py`). If `--serialized-model` is passed, the model state is serialized using `flax.serialization.to_state_dict`. ## Migration rules format The migration rules are specified in a `CheckpointMigrationConfig` textproto file. 
`CheckpointMigrationConfig` is defined in `proto/checkpoint_migration_config.proto`. It contains a list of `CheckpointMigrationRule` messages. Every rule has a `from_path` and `to_path` field (both optional, but at least one must be set). These are string fields, which are JSON lists of strings and integers, and are mapped to pytree KeyPaths: * Integers are used as `SequenceKey` (list/tuple indices). * Strings are used as `DictKey` (dict keys). * If any other type is met in the path, an error is raised. * If other key types are met in `PathKey`, an error is also raised. * By default, all keys that are present both in the old and new state are preserved (copied from old to new). * If both `from_path` and `to_path` are set, they must be different. The old values at `from_path` are copied to `to_path` in the new state. * If only `to_path` is set, it means that the new state at `to_path` is taken from the new initialized state. * If only `from_path` is set, it means that the old state at `from_path` is not present in the new state, and is ignored. Without this rule, the migration would fail if the old state has keys that are not present in the new state. * If we want to keep an initialized ("new") value at a path that is also present in the old state, we use two rules: one with only `to_path` to initialize it, and one with only `from_path` to ignore the old value at that path. Note that the `from_path` and `to_path` are prefixes of the actual paths, so the actual subtrees are copied/ignored. ## Working of the migration 1. Both new and old states are created. 2. Both of them are flattened using `jax.tree_util.tree_flatten_with_path` into list of `(KeyPath, value)` pairs and a treedef. 3. Lists of `(KeyPath, value)` are converted to dicts mapping `KeyPath` to `value`. 4. We build a set of `source_paths` (keys of the old state). 5. We build a set of `dest_paths` (keys of the new state). 6. 
Rules are applied in arbitrary order as follows: * If both `from_path` and `to_path` are set, the values prefixed by `from_path` are copied to the corresponding `to_path` in the new state. Of course, `to_path` in the new state must exist (i.e. we never create new keys). The copied source paths are removed from `source_paths`. The corresponding destination paths are removed from `dest_paths`. * If only `to_path` is set, we delete all paths prefixed by `to_path` from `dest_paths`. This means that we keep the initialized value at this path. * If only `from_path` is set, we delete all paths prefixed by `from_path` from `source_paths`. This means that we ignore the old value at this path. 7. After all rules are applied, we check that `source_paths` and `dest_paths` are equal. If not, the migration is incomplete, and we raise an error. 8. After this, we copy all remaining `source_paths` (i.e. those that were not mentioned in any rule) to the new state. This means that by default, all keys that are present both in the old and new state are preserved (copied from old to new). 9. Finally, we unflatten the new state dict back to a pytree using the new treedef. If `--new_checkpoint` is set, we save the new state to the specified path. Otherwise, we just print that the migration is possible with the given rules. Instead of just printing one error, the tool should print ALL errors it finds. It should be helpful for fixing the rules. ## Implementation notes * There is a `justfile`, e.g. `just build-proto` to build the protos. * We use `uv`. Example usage: ```bash uv run migrate-checkpoint --config=~/tmp/lc0/config/overfit.textproto ``` ================================================ FILE: docs/example.textproto ================================================ # Example configuration file for lczero-training # This file demonstrates all available configuration options with their default # values and explanations of what each setting controls. 
name: "little-teapot" data_loader { stage { name: "file_path_provider" file_path_provider { # Directory with training data files. directory: "/home/crem/tmp/2025-07/lczero-training/data2" output { queue_capacity: 16 # Internal file queue size } } } stage { name: "chunk_source_loader" input: "file_path_provider" chunk_source_loader { threads: 1 # Threads for loading chunks frame_format: V6TrainingData # Training data format (V6TrainingData or V7TrainingData) output { queue_capacity: 16 # Output queue for chunk sources } } } stage { name: "shuffling_chunk_pool" input: "chunk_source_loader" shuffling_chunk_pool { chunk_pool_size: 50000 # Shuffle buffer size (chunks in memory) source_ingestion_threads: 1 # Threads for ingesting new sources chunk_loading_threads: 4 # Threads for loading chunk data output { queue_capacity: 16 # Output queue for shuffled chunks } } } stage { name: "chunk_rescorer" input: "shuffling_chunk_pool" chunk_rescorer { threads: 1 # Threads for chunk rescoring output { queue_capacity: 16 # Output queue for rescored chunks } syzygy_paths: "/path/to/tb" # Tablebase search paths (comma-separated) dist_temp: 1.0 # Policy temperature applied during rescoring dist_offset: 0.0 # Policy offset applied during rescoring dtz_boost: 0.0 # DTZ boost for endgame policy tuning new_input_format: -1 # Keep original input format (-1 disables change) deblunder_threshold: 0.10 # Threshold for policy deblundering adjustments deblunder_width: 0.06 # Width controlling smoothing around threshold } } stage { name: "chunk_unpacker" input: "chunk_rescorer" chunk_unpacker { threads: 1 # Threads for unpacking chunks # Probability of sampling each position within a chunk. 
position_sampling_rate: 0.03 output { queue_capacity: 16 # Output queue for unpacked frames } } } stage { name: "shuffling_frame_sampler" input: "chunk_unpacker" shuffling_frame_sampler { threads: 1 # Threads for frame sampling reservoir_size_per_thread: 1000000 # Sampling reservoir per thread output { queue_capacity: 16 # Output queue for sampled frames } } } stage { name: "tensor_generator" input: "shuffling_frame_sampler" tensor_generator { threads: 1 # Threads for tensor generation batch_size: 128 # Batch size for tensors output { queue_capacity: 8 # Output queue for batched tensors } } } output: "tensor_generator" } model { defaults { compute_dtype: F32 activation: ACTIVATION_MISH ffn_activation: ACTIVATION_MISH } embedding { dense_size: 512 embedding_size: 1024 dff: 1536 } encoder { num_blocks: 15 dff: 1536 d_model: 1024 heads: 32 smolgen { hidden_channels: 32 hidden_size: 256 gen_size: 256 activation: ACTIVATION_SWISH } } policy_head { name: "vanilla" embedding_size: 1024 d_model: 1024 } value_head { name: "winner" num_channels: 128 } movesleft_head { name: "main" num_channels: 32 } } training { schedule { steps_per_network: 250 chunks_per_network: 50000 } lr_schedule { starting_step: 0 duration_steps: 0 # 0 means indefinite duration lr: 0.001 } # Example multi-phase schedule with warmup: # lr_schedule { # starting_step: 0 # duration_steps: [1500, 500, 0] # Warmup, transition, then indefinite # lr: [0.0, 0.0005, 0.0005] # Start at 0, ramp to 0.0005, hold # transition: [LINEAR, CONSTANT] # Linear warmup, then constant # # Missing transitions default to CONSTANT. Last duration 0 = indefinite. # } checkpoint { path: "/home/crem/tmp/2025-09/lc0_training/checkpoint" max_to_keep: 5 } optimizer { nadamw { beta_1: 0.9 beta_2: 0.98 epsilon: 1e-7 weight_decay: 0.0001 # Rule order matters: first match wins. 
decay_selector { rule { match: "**/bias" include: false } rule { match: "**/ln*/**" include: false } rule { match: "**/embedding/embedding/**" include: false } rule { match: "**/policy_heads/**" include: true } rule { match: "**/value_heads/**" include: true } rule { match: "**/movesleft_heads/**" include: true } otherwise_include: false } } # Alternative optimizers: # nadam { beta_1: 0.9 beta_2: 0.999 epsilon: 1e-8 } # sgd { momentum: 0.9 nesterov: true } # freeze_selector freezes matching weights (they receive no gradient # updates). # freeze_selector { # rule { match: "**/embedding/**" include: true } # otherwise_include: false # } } max_grad_norm: 10.0 # Global gradient-norm clip; omit or set to 0 to disable. losses { policy { head_name: "vanilla" metric_name: "main_ce" weight: 1.0 illegal_moves: MASK type: CROSS_ENTROPY } policy { head_name: "vanilla" metric_name: "main_kl" weight: 1.0 illegal_moves: MASK type: KL temperature: 1.0 # Softmax temperature applied before KL evaluation } value { head_name: "winner" weight: 1.0 } movesleft { head_name: "main" weight: 1.0 } } } metrics { tensorboard_path: "/tmp/tensorboard/myrun" } export { destination_filename: "/home/crem/tmp/2025-08/lc0_training/exported_models/lc0-{datetime}-{step:08d}.pb.gz" upload_training_run: 3 } ================================================ FILE: docs/heads.md ================================================ # Neural Network Heads Documentation This document describes the various policy and value heads used in the network, their training targets, and any specific scaling or transformations applied during training. ## Policy Heads ### 1. Vanilla Policy (`vanilla`) * **Description**: The standard policy head predicting the best move. * **Training Target**: The `probabilities` vector from the training data (MCTS visit counts). * **Scaling/Transformation**: No specific scaling is applied to the target. 
The loss function compares the network output (logits) directly against the target probability distribution. * **Loss Function**: Cross-entropy loss (Kullback-Leibler divergence). ### 2. Optimistic Short-Term Policy (`optimistic_st`) * **Description**: A policy head trained to be "optimistic" about the outcome, focusing more on moves that lead to better short-term evaluations. It uses a weighted loss function where positions that the network underestimates (target > prediction) are weighted more heavily. * **Training Target**: The same `probabilities` vector as the `vanilla` head. * **Scaling/Transformation**: * **Mechanism**: The "optimism" is applied as a **sample weight** in the loss function. * **Computation Guide**: 1. **Inputs**: * `v_st_target`: Short-term value target (scalar, from `st` head target). * `v_st_pred`: Short-term value prediction (scalar, from `st` head output). * `v_st_err_pred`: Predicted squared error of the short-term value (scalar, from `st_err` head output). 2. **Standard Deviation Estimation**: * `sigma = sqrt(v_st_err_pred)` 3. **Z-Score Calculation**: * `z = (v_st_target - v_st_pred) / (sigma + 1e-5)` * This measures how many standard deviations the target is away from the prediction. A positive `z` means the position is better than predicted (underestimated). 4. **Weight Calculation**: * `strength = 2.0` (default configuration) * `weight = sigmoid((z - strength) * 3)` * **Interpretation**: * If `z` is large (positive), meaning the target is much higher than predicted (highly underestimated), the weight approaches 1. * If `z` is small or negative (overestimated or accurately predicted), the weight approaches 0. * The `strength` parameter shifts the sigmoid, controlling the threshold of "optimism" required to trigger training. * **Loss Function**: Weighted Cross-entropy loss. `loss = weight * CrossEntropy(target, output)`. ### 3. 
Soft Policy (`soft`) * **Description**: A policy head trained on a "softened" version of the MCTS probabilities, encouraging exploration or capturing more of the distribution's shape. * **Training Target**: The `probabilities` vector from the training data. * **Scaling/Transformation**: * **Mechanism**: **Temperature scaling** is applied to the target probabilities **inside the loss function**. * **Calculation**: `target = target^(1/temperature)`. * **Location**: This transformation happens in the loss calculation logic (specifically in `correct_policy` helper in `tfprocess.py`), **not** at the tensor generation stage. The tensor generator provides the raw `probabilities`. * **Loss Function**: Cross-entropy loss against the temperature-scaled target. ### 4. Opponent Policy (`opponent`) * **Description**: Predicts the move the opponent actually played in the game. * **Training Target**: `opp_played_idx` (Integer index of the move played by the opponent). * **Scaling/Transformation**: The integer index is converted to a one-hot vector inside the loss function. * **Loss Function**: Cross-entropy loss. ## Value Heads ### 1. Winner (`winner`) * **Description**: Predicts the final game outcome (Win/Draw/Loss). * **Training Target**: A 3-element probability vector derived from `result_q` and `result_d` in the training data. * `Win = (1 + result_q - result_d) / 2` * `Loss = (1 - result_q - result_d) / 2` * `Draw = result_d` * **Scaling/Transformation**: None. * **Loss Function**: Cross-entropy or MSE depending on configuration. ### 2. Q-Value (`q`) * **Description**: Predicts the expected value of the position based on the MCTS search (best Q). * **Training Target**: A 3-element probability vector derived from `best_q` and `best_d` in the training data. * `Win = (1 + best_q - best_d) / 2` * `Loss = (1 - best_q - best_d) / 2` * `Draw = best_d` * **Scaling/Transformation**: None. * **Loss Function**: MSE or Cross-entropy. ### 4. 
Q-Value Error (`q_err`) * **Description**: Predicts the squared error of the `q` head prediction compared to the target. * **Training Target**: `(q_target - q_pred)^2`. * **Scaling/Transformation**: None. * **Loss Function**: MSE. ### 5. Short-Term Value (`st`) * **Description**: Predicts the short-term evaluation of the position (e.g., from a shallow search or static eval). * **Training Target**: A 3-element probability vector derived from `q_st` and `d_st` in the training data. * `Win = (1 + q_st - d_st) / 2` * `Loss = (1 - q_st - d_st) / 2` * `Draw = d_st` * **Scaling/Transformation**: None. * **Loss Function**: MSE or Cross-entropy. ### 6. Short-Term Value Error (`st_err`) * **Description**: Predicts the squared error of the `st` head prediction compared to the target. Used for calculating uncertainty/variance for the `optimistic_st` head. * **Training Target**: `(st_target - st_pred)^2`. * **Scaling/Transformation**: None. * **Loss Function**: MSE. ## Auxiliary Heads ### 1. Moves Left (`moves_left`) * **Description**: Predicts the number of moves remaining in the game. * **Training Target**: `plies_left` (Float). * **Scaling/Transformation**: * **Mechanism**: The target and output are scaled down by a factor (e.g., 20.0) **inside the loss function** to bring the loss magnitude into a similar range as other losses. * **Loss Function**: Huber loss. ================================================ FILE: docs/index.md ================================================ # Index * [Overview, glossary and file formats](overview.md) — An overview of the project, including definitions of some terms and file formats used. * [Data Loader](loader.md) — A module for loading, preprocessing, shuffling, and feeding training data. * [Training Tuple Format](training_tuple.md) — A description of the training tuple format used in the project. This is an interface between the data loader and model training. * [Command-Line Tools](cli.md) — How to run the tools via `uv run`. 
================================================ FILE: docs/loader.md ================================================ # Data Loader The Data Loader is a C++ module (exposed to Python via pybind11) that handles loading, preprocessing, shuffling, and feeding training data for the Leela Chess Zero training process. ## Python Integration The loader has been exposed to Python through pybind11, allowing direct use of the C++ `DataLoader` from Python code. Key aspects: * **Configuration**: Generated protobufs (for example `DataLoaderConfig`) are passed directly to the binding, or via the convenience wrapper `lczero_training.dataloader.make_dataloader`. * **Control Plane**: Use `DataLoader.send_control_message()` with `proto.stage_control_pb2.StageControlRequest` to fan out commands such as chunk-pool anchor updates. * **Memory Management**: Uses `unique_ptr::release()` with `py::return_value_policy::take_ownership` for efficient tensor ownership transfer * **Output Format**: Returns tuple of numpy arrays compatible with JAX through the buffer protocol * **Usage**: `from lczero_training.dataloader import make_dataloader` ## High-Level Overview The Data Loader consists of the following stages connected through a [Queue](../csrc/utils/queue.h): * [FilePathProvider](../csrc/loader/chunk_feed/file_path_provider.h) — Training data discovery worker (watches a directory and provides feed of filenames) * [ChunkSourceLoader](../csrc/loader/chunk_feed/chunk_source_loader.h) — Reads chunks from files, providing a stream of chunks. * [ShufflingChunkPool](../csrc/loader/chunk_feed/shuffling_chunk_pool.h) — Keeps a set of chunks, managing the last `num_chunks` available and removing old ones, and outputting them in shuffled order. * (skip for now) [ChunkValidator](../csrc/loader/chunk_feed/chunk_validator.h) — Filters the chunk stream, filtering out invalid chunks. 
* [ChunkRescorer](../csrc/loader/stages/chunk_rescorer.h) — Rescores chunks using Syzygy tablebases and configurable policy adjustments. * [ChunkUnpacker](../csrc/loader/chunk_feed/chunk_unpacker.h) — Unpacks chunks into frames, which are then processed by the next stages. * [ShufflingFrameSampler](../csrc/loader/shuffling_frame_sampler.h) — Takes a stream of frames and provides shuffled batches of frames for training, using reservoir sampling. * [TensorGenerator](../csrc/loader/tensor_generator.h) — Takes frames and provides tensor buffers for the training process. ## Metrics All stages expose the following metrics in [DataLoaderMetricsProto](../proto/training_config.proto): * load — for measuring how much time the threads are working vs idle. * queue — for monitoring queue statistics. There are the following exceptions: * ShufflingChunkPool * Has two thread pools (indexing and chunk loading), so needs two `load` metrics. * Needs metric (statisticsmetric) for current number of chunk sources. * Needs metric (simple value) for current number of chunks in the pool. * Needs metric (simple value) for pool capacity. * ChunkUnpacker * Needs to track the number of bad chunks (statisticsmetric) * ShufflingFrameSampler * Needs capacity of reservoir (simple value) * Needs current size of reservoir (simple value) ### Adding a new metric To add a new metric, you need to: * Add a field in the relevant proto in [training_config.proto](../proto/training_config.proto) * Add a field in a relevant stage to collect the metric (if needed). * Implement FlushMetrics() function in your stage. It should reset internal state metrics to zero, and return what it accumulated (or latest value). * In [DataLoader::MetricsThread](../csrc/loader/data_loader.cc) call FlushMetrics() of the relevant stage and assign it to proper proto field. * In the proper [Queue or Stage widget](../src/lczero_training/tui/stage_widgets.py) create proper UI elements and update update_metrics() function to update them. 
* For load metrics, you have to create LoadMetricPauser per thread, e.g. see [ShufflingChunkPool](../csrc/loader/chunk_feed/shuffling_chunk_pool.cc). * For queue metrics, just collect the queue statistics in FlushMetrics(). * For all metrics, both "all time" and "during last second" are provided. The latter is useful to determine rate (items per second). * In general, use StatisticsProtoInt64 (or StatisticsProtoDouble) for distribution metrics, and simple number field for additive metrics like counts. ## TensorGenerator Batch size is configurable in the stage options. The `TensorGenerator` stage takes frames from the input queue and produces a tuple of tensors, returned as [TensorTuple](../csrc/utils/tensor.h). The first dimension of every tensor in the tuple is the batch size, and the rest are described in the [Training Tuple Format](training_tuple.md). ## Stage interface All stages implement a similar API and structure, although not sharing any base class. All stages/workers (except pure producers) wait for the input queue to Close(), then Close() output Queue. ```cpp class Stage { public: using InputType = ...; // Type of input data for this stage using OutputType = ...; // Type of output data from this stage // input_queue is omitted in the producer stages like FilePathProvider. Stage(Queue* input_queue, /* other params */); Queue* output(); private: ThreadPool thread_pool_; Queue* input_queue_; Queue output_queue_; }; ``` ## Chunk Set The Chunk Set takes the feed of chunk sources, indexes them, and assigns the chunk range (base; base+num_chunks) to each chunk source. It aims to keep the newest chunk sources that cover the last `chunks_window_` chunks, and removes old chunk sources when new ones are added. On the output side, it returns a stream of chunks within the (`last - chunk_window_`, `last`) range, without repetitions. 
To do that, it utilizes a [StreamShuffler](../csrc/utils/stream_shuffler.h) which provides the shuffled stream of numbers within the (dynamic) range. The Chunk Set gets the stream of chunks (initial chunks are read in the constructor), and then starts: * Input indexing worker pool. Input indexing worker pool calls `Index()` on each chunk source, and then appends the chunk source to the `chunk_sources_` deque (under mutex). * Chunk output worker pool. Fetches the next number from `stream_shuffler_` under mutex, then reads the chunk from the chunk source using per-source mutex. If `stream_shuffler_` runs out of numbers, it's reset to the range (`last - chunk_window_`, `last`) (and a warning message is logged). ## ShufflingFrameSampler The sampler uses reservoir sampling: * It has a reservoir of predefined size (1000000 is quite typical) * Initially it just fills the reservoir with frames from the input queue until it's full. * After that, it picks random frames from the reservoir and outputs them, refilling the used spot from the input queue. * It closes the output queue when either explicit Close() is called or the input queue is closed. * `using FrameType = V6TrainingData;`, use `absl::FixedArray` for the reservoir. ## Anchors in ShufflingChunkPool The training pipeline aims to start training a new epoch when a certain number of new chunks are available. To do that, we reset the running counter of chunks in the ShufflingChunkPool when we start waiting for new chunks, and then wait until the counter reaches the desired number. However, when the script restarts, we also want to approximately know how many new chunks to wait for before starting the training. To do that, we use the concept of "anchors". An anchor is just a GetChunkSortKey() of a given chunk that we remember. More specifically, we add the following functions to ShufflingChunkPool: * `std::string ResetAnchor()` — resets the anchor to the latest chunk seen so far, and returns its sort key. 
Also resets the internal counter of chunks seen since the anchor. * `int ChunksSinceAnchor()` — returns the number of chunks seen since the anchor. * `std::string CurrentAnchor()` — returns the current anchor sort key. * `void SetAnchor(std::string_view)` — is usually called BEFORE starting processing chunks. Does not reset the counter, but sets the anchor to the given value. When the read chunk has the same key as the anchor, the counter is reset to zero. Python clients access this functionality by issuing `StageControlRequest` messages through `DataLoader.send_control_message()`. The daemon pipeline demonstrates how the first chunk-pool response is used to update anchor state. The anchor functionality works differently during initial load vs. ongoing processing: **Initial Load (backward processing):** * Chunks are processed in newest-first order during initial scan * If no anchor is set: count all chunks * If anchor is set: only count chunks newer than the anchor; reset counter to 0 when anchor is encountered **Ongoing Processing (after initial load):** * New chunks are processed as they arrive * Counter increments by GetChunkCount() for each new chunk source * Counter resets to 0 when a chunk source matching the anchor is encountered Metrics: In [ShufflingChunkPoolMetricsProto](../proto/training_metrics.proto) we add: * `int32 chunks_since_anchor` — number of chunks seen since the anchor. Simple numerical field. * `string anchor` — current anchor sort key. In [UI](../src/lczero_training/tui/stage_widgets.py) we: * Update the last_chunk_key display to have a label "Last:" * Add a new row "⚓:" for the anchor key. * Add a new row "Since ⚓:" for chunks_since_anchor. ================================================ FILE: docs/new_stage.md ================================================ # Writing a New Data Loader Stage This guide walks through the lifecycle of adding another stage to the dynamic data loader pipeline. 
It assumes you are working in the C++ orchestrator (under `csrc/loader/stages/`) and that the Python bindings already consume staged configurations. ## 1. Design the Stage Surface - **Purpose and data flow**: Decide whether the stage produces new data (no upstream input) or transforms items from an existing queue. - **Configuration shape**: Determine which knobs are required during construction (thread counts, capacities, etc.). These become fields on the stage-specific protobuf message. - **Outputs and control hooks**: Clarify what queue type the stage emits and whether it needs control-plane messages. ## 2. Extend the Protobufs - Add a new `message Config` to `proto/data_loader_config.proto`. - For single-output stages, add `optional QueueConfig output = N` to configure the output queue. `QueueConfig` provides `queue_capacity` (default 4), `overflow_behavior` (BLOCK, DROP_NEW, KEEP_NEWEST), and optional `name`. - For multi-output stages, use `repeated QueueConfig output` with parallel configuration arrays (see `ChunkSourceSplitterConfig` for reference). - Update `StageConfig` with an `optional Config` entry so the stage can be referenced from the `repeated stage` list. - If the stage emits custom metrics, extend `StageMetricProto` in `proto/training_metrics.proto`. Prefer the existing `load_metrics`, `queue_metrics`, and `count_metrics` collections when possible. - When the stage needs control requests or responses, extend `proto/stage_control.proto` so they can be carried through `StageControlRequest`/`StageControlResponse`. - Regenerate protobufs (`meson compile -C builddir` or `just build-proto`). ## 3. Choose a Base Class - **Use `SingleInputStage`** when the stage consumes exactly one upstream queue. The helper provides `input_queue()` to access the typed `Queue*` and implements `SetInputs()` to wire the input during initialization. - **Use `SingleOutputStage`** when the stage produces exactly one output queue. 
The helper manages the output queue, implements `GetOutput()` with name validation, and surfaces the typed `Queue*` via `output_queue()`. - **Most stages inherit from both** `SingleInputStage` and `SingleOutputStage` using virtual inheritance (both base classes virtually inherit from `Stage` to avoid the diamond problem). Example: ```cpp class MyStage : public SingleInputStage, public SingleOutputStage { public: explicit MyStage(const MyStageConfig& config) : SingleInputStage(config), SingleOutputStage(config.output()) {} }; ``` - **Inherit `Stage` directly** when the stage has multiple inputs, multiple outputs, or manages more complex wiring. In that case you must implement `SetInputs()`, input/output discovery, and `GetOutput()` yourself. - Place declarations in `csrc/loader/stages/.h` and definitions in the matching `.cc` file. ## 4. Implement the Stage API - **Constructor**: Initialize base classes with config and `config.output()`. Store additional config fields and initialize worker pools. Avoid starting threads here. - **`SetInputs(absl::Span inputs)`**: Only implement if you inherit from `Stage` directly. `SingleInputStage` provides this automatically and validates that exactly one input is provided. For stages with no inputs, validate that the span is empty. - **`Start()`**: Launch background work. Acquire `Queue::Producer` instances from `output_queue()->CreateProducer()` for emitting data and honour `stop_requested_` flags so shutdown is cooperative. The input queue is available via `input_queue()` at this point. - **`Stop()`**: Close queues via `output_queue()->Close()`, signal workers to exit, and join threads. Remember that downstream stages expect `Queue::Close()` to signal completion. - **`GetOutput(std::string_view name)`**: Only implement if the stage has multiple outputs. `SingleOutputStage` provides this automatically for single-output stages, including name validation. 
- **`Control()`**: Handle relevant `StageControlRequest` sub-messages and return a populated `StageControlResponse` wrapped in `std::optional`. Return `std::nullopt` for requests the stage does not recognise. ## 5. Report Metrics - **Accumulate state** while workers run (e.g., load metrics, counters, queue statistics). - **`FlushMetrics()`** should snapshot the current values, reset internal counters as needed, and populate `StageMetricProto`. Use helpers like `MetricsFromQueue("output", *output_queue())` to expose queue utilisation under `queue_metrics`, and append load information via `load_metrics`. - For multiple queues or distinct metric groups, add additional entries with meaningful names (`"output"`, `"prefetch"`, etc.) so downstream tooling can pick the right series. - If you rename or split a metric, document the change and update dashboards. For example, `ShufflingChunkPool` now emits `chunks_current` (window size) and `chunks_total` (total indexed chunks) instead of a single `chunks` series; the Grafana panels consuming the old series were repointed to `chunks_current` so the graphs remain accurate. ## 6. Register the Stage - Update `CreateStage` in `csrc/loader/stages/stage_factory.cc` to construct the new class when its config is present. Enforce the "exactly one sub-config" rule by keeping the existing `CountStageConfigs()` logic in sync. - Ensure `meson.build` lists the new source files so the static library rebuilds. ## 7. Wire Up Tests - Add focused unit tests under `csrc/loader/stages/` validating constructor errors, thread lifecycle, metric flushing, and (if applicable) control-plane behaviour. - Provide integration coverage where the stage participates in a small pipeline built from serialized `DataLoaderConfig` messages. - If Python bindings surface stage-specific behaviour, extend the relevant `pytest` suites too. ## 8. 
Update Documentation and Examples - Document new config fields in `docs/` (for example, augment `docs/loader.md` or create stage-specific notes). - Add sample snippets or textproto fragments showing how to reference the stage in a pipeline. - Mention any new control commands so the daemon/TUI maintainers know how to surface them. Following these steps keeps the stage ecosystem consistent: configurations are validated at construction time, queues remain type-safe, metrics feed the UI, and Python clients continue to operate through the generic factory and control plane. ================================================ FILE: docs/overview.md ================================================ # Overview, Glossary, and File Formats This document serves as a glossary of terms used in the project and describes the file formats utilized. * Training data **frame** is a data structure that holds information needed for the NN training about a single chess position. Currently it's a fixed sized struct, e.g. [V6TrainingData](../libs/lc0/src/trainingdata/trainingdata.h). * **Chunk** is a sequence of frames from a single game. Currently, they are stored in a gzipped file where frames are concatenated together. * **Chunk Source** is a file that contains one or more chunks. In older versions of the code, it was a single gzipped file. In the new version, it also may be an uncompressed .tar file. ================================================ FILE: docs/shuffling_pool_hanse_sampling.md ================================================ # Implement single position sampling in Shuffling Pool This document defines a new way of sampling in [Shuffling Pool](../csrc/loader/stages/shuffling_chunk_pool.h) (and .cc). ## reshuffle_count -> use_count In `TrainingChunk` in [training_chunk.h](../csrc/loader/stages/training_chunk.h) the `reshuffle_count` should be renamed to `use_count`. As with the new sampling method, usage is not necessarily tied to reshuffling. 
Also the code that [uses it](../csrc/loader/stages/chunk_unpacker.cc) will need to be updated. In `ChunkSourceItem`, instead of one `reshuffle_count` per entire chunk source, we'll have a `std::vector use_counts`, initially filled with zeros. Instead of updating it on reshuffling, we will update it for individual chunks when they are returned. In `GetNextChunkData()` we return the old value (i.e. 0 for the first time). Local struct `ShufflingChunkPool::ChunkData` in ../csrc/loader/stages/shuffling_chunk_pool.cc should also have `use_count` instead of `reshuffle_count`. ## Configuration changes In the [config](../proto/data_loader_config.proto), in `ShufflingChunkPoolConfig`, we add the following fields: ```proto message ShufflingChunkPoolConfig { // existing fields... optional uint64 hanse_sampling_threshold = 6; // by default, do not use new sampling. optional double hanse_sampling_gamma = 7 [default = 1.0]; } ``` ## Algorithm changes In addition to `use_counts`, `ChunkSourceItem` will have a new field: `std::vector num_records;`. It will contain the number of records in each chunk, and will act as a cache. Initially, it's filled with zeros. When `hanse_sampling_threshold` is not set, the sampling method is the same as before. When `hanse_sampling_threshold` is set, we will use the new sampling method. It goes like this: `GetNextChunkData()` doesn't call `LoadChunkData()` right away. Instead, it first checks if `num_records[chunk_index]` is zero. If so, it calls `LoadChunkData()` to load the chunk, and counts the number of records in it (by dividing its size by `sizeof`). Then, we decide whether to return this chunk or to sample again. We do this by drawing a random number `u` uniformly from `[0, 1)`, and comparing it to `p = min(1.0, num_records / hanse_sampling_threshold) ^ hanse_sampling_gamma`. If `u < p`, we return this chunk (we need to call `LoadChunkData()` if we didn't already) and increment use_count. Otherwise, we sample again (i.e. 
pick a new `chunk_index` and repeat the process) — without incrementing `use_count`. ================================================ FILE: docs/training_tuple.md ================================================ # Training Tuple Format The `convert_v6_to_tuple` function in `tf/chunkparser.py` processes training data and produces a 5-element tuple: `(planes, probs, winner, best_q, plies_left)`. When these raw byte strings are interpreted as NumPy arrays, they have the following shapes for each training example: 1. **`planes`**: `(112, 64)` as a `float32` array. * This represents the board state as 112 feature planes, each of size 8x8 (64). The original 104 planes from the input are augmented with 8 additional planes for information like castling rights, side to move, the rule 50 count, and board edge detection. 2. **`probs`**: `(1858,)` as a `float32` array. * This corresponds to the `float probabilities[1858]` member in the `V6TrainingData` C++ struct, representing the policy probabilities for all possible moves. 3. **`winner`**: `(3,)` as a `float32` array. * This holds the game's outcome from the current player's perspective, representing the probabilities for a win, draw, and loss, respectively. 4. **`best_q`**: `(3,)` as a `float32` array. * Similar to `winner`, this stores the value of the position after search (the Q-value), also represented as win, draw, and loss probabilities. 5. **`plies_left`**: A scalar `float32`. * This value represents the estimated number of plies remaining until the end of the game. ================================================ FILE: docs/tui.md ================================================ # UI Design The application will present a single-screen dashboard with a classic blue background, organized into several key, always-visible panes. ## 1. Overall Layout The screen is divided into four main sections: 1. **Header Bar (Top):** A slim, single-line bar at the very top. 2. 
**Data Pipeline Pane (Top-Left/Main):** The largest and most detailed pane. 3. **JAX Training Status Pane (Right):** A pane dedicated to live training metrics. 4. **Log Pane (Bottom):** A pane for displaying raw log output. ## 2. Header Bar This bar provides high-level, global status at a glance. - **Uptime:** Total wall-clock time since the script was launched. - **Overall Stage:** The current high-level state of the application. Will display one of: `WAITING FOR DATA`, `TRAINING`, `EXPORTING`, or `ERROR`. ### 3. Data Pipeline Pane This pane visualizes the flow of data through the C++ pipeline. The new design is a vertically scrollable list where every stage and queue consumes one row (two only when extra detail is needed). Rows are rendered in the same order as traffic flows through the loader, preserving the mental model of the pipeline without relying on borders or grid positioning. - **Pipeline Stages (Rows):** Each stage of the C++ pipeline is rendered as a single line. - **Stage Heading:** The label uses the canonical stage name reported in the metrics, so newly added stages automatically appear without additional UI wiring. - **Load Metrics:** Stages with thread pools render load in `load active/total` format. - **Specific Stats:** Each metric is rendered as an individual "chip" widget, e.g. skipped file counters or pool sizes, so additional detail can be added without breaking alignment. - **Queues (Rows):** Each stage row is followed by a queue row that surfaces the metrics of the outgoing queue. - **Queue Heading:** The row label is the canonical stage name reported by the daemon. The first chip shows the queue name when it differs from the default. - **Throughput:** A chip displays the 1-second `items/s` rate and turns red when the rate drops to zero. - **Totals:** A chip shows the lifetime count of elements that passed through the queue, formatted with apostrophes. 
- **Fill State:** A horizontal progress bar visualises the average queue fill against capacity, followed by a chip with the numeric `avg/capacity` display. Unknown values fall back to `--`. - **Train/Validation/Test Splitter:** - The pipeline view will show a "Stream Splitter" stage. - Hotkeys (e.g., F1, F2, F3) will allow the user to instantly switch the view to show the stats for the 'Training', 'Validation', or 'Test' streams. - After this stage, the UI will display the pipeline stats for **one** stream at a time (defaulting to 'Training'). ### 4. Training schedule/pipeline pane The area below the Data Pipeline Pane is split horizontally into two sections: Training Schedule (left) and JAX Training Status (right). The Training Schedule pane shows: - **Combined uptime and stage line**: "Uptime: 2d 14:30:45 Stage: TRAINING" - Uptime includes days when >24 hours (format: "2d 14:30:45") - Stage shows current training state (WAITING_FOR_DATA, TRAINING, EXPORTING, ERROR) - **Completed epochs**: Simple counter of epochs completed since daemon start - **New chunks progress bar**: Shows chunks collected since training start vs. target - Indeterminate state when target is unknown (0) - **Training time progress bar**: Current training time vs. previous training duration - Indeterminate state when no previous duration exists - **Cycle time progress bar**: Current cycle time vs. 
previous cycle duration - Indeterminate state when no previous duration exists **Implementation details:** - **Header bar**: Completely empty (no content) - **Data structure**: `TrainingScheduleData` dataclass with all timing fields: - `current_stage: TrainingStage` (enum) - `completed_epochs_since_start: int` - `new_chunks_since_training_start: int` - `chunks_to_wait: int` - `total_uptime_seconds: float` - `current_training_time_seconds: float` - `previous_training_time_seconds: float` - `current_cycle_time_seconds: float` - `previous_cycle_time_seconds: float` - **Timing computation**: All timing values computed in daemon, not TUI - **Progress bars**: Show indeterminate state when maximum values ≤ 0 - **Layout**: Compact single-line widgets with no extra padding/margins - **Files**: - `training_widgets.py`: Widget implementations - `pipeline.py`: Training state tracking and data collection - `daemon.py`: Metrics collection and transmission - `messages.py`: Protocol definitions with enum serialization support ### 5. JAX Training Status Pane This pane is dedicated to the live status of an active JAX training run. It remains blank or shows summary info when the system is not actively training. - **Epoch Progress:** A progress bar showing completion of the current epoch (`Step 12345 / 50000`). - **Performance Metrics:** - Steps per second: `345.6 steps/s`. - Estimated Time Remaining (ETR) for the current run. - Total wall time spent on the current training run. - **Loss Values (Numerical Only):** - A prominent display of the **Total Loss**. - A compact 2-column grid displaying the individual values for the **7 Head Losses**. ### 6. Log Pane A pane across the bottom of the screen. - **Content:** A direct feed of all output sent to `stderr` from any part of the application (Python or C++). - **Functionality:** The pane will be scrollable and will hold a fixed number of lines (e.g., 1000) to prevent unbounded memory usage, discarding the oldest lines as new ones arrive. 
================================================ FILE: docs/weights_tool.md ================================================ # lc0-weights - Weight Manipulation Tool ## Overview `lc0-weights` is a command-line tool and Python library for manipulating Leela Chess Zero neural network weight files. It enables arithmetic operations on networks, component grafting, and format conversion without requiring JAX or other heavy dependencies. **Key capabilities:** * **Arithmetic operations**: Add, subtract, multiply networks (e.g., model interpolation/averaging) * **Grafting**: Replace specific network components (policy heads, value heads, encoder layers) * **Format conversion**: Convert between LINEAR16, FLOAT16, and BFLOAT16 encodings * **Pure Python/NumPy**: No JAX or TensorFlow dependencies required ## Quick Start Interpolate two networks with equal weights: ```bash # Using inline expression uv run lc0-weights \ --expr "output = weights('network_a.pb.gz') * 0.5 + weights('network_b.pb.gz') * 0.5" \ --output interpolated.pb.gz # Using a script file echo "output = weights('network_a.pb.gz') * 0.5 + weights('network_b.pb.gz') * 0.5" > interpolate.py uv run lc0-weights interpolate.py --output interpolated.pb.gz # Using stdin echo "output = weights('network_a.pb.gz') * 0.5 + weights('network_b.pb.gz') * 0.5" | \ uv run lc0-weights --output interpolated.pb.gz ``` ## Command-Line Interface ### Arguments | Argument | Required | Description | | ------------ | -------- | -------------------------------------------------------------------- | | `--expr` | No | Python expression to execute | | `script` | No | Path to Python script file (positional, used if --expr not given) | | `--input` | No | Pre-load input as `NAME=PATH` (can be used multiple times) | | `--output` | No | Output path (if `output` variable is set in expression) | | `--encoding` | No | Output encoding format: LINEAR16, FLOAT16 (default), or BFLOAT16 | **Note:** You must provide either `--expr`, a script file path, or 
pipe input via stdin. ### Available Functions in --expr When executing expressions with `--expr`, the following are available: * `weights(path)`: Load a weight file * `save(net, path, encoding='FLOAT16')`: Save a network * `np`: NumPy module * `lc0`: Protobuf module (for accessing constants) ### CLI Examples #### Simple Interpolation **Using inline expression:** ```bash uv run lc0-weights \ --expr "output = weights('A.pb.gz') * 0.5 + weights('B.pb.gz') * 0.5" \ --output result.pb.gz ``` **Using a script file:** ```bash # Create script file cat > interpolate.py << 'EOF' A = weights('A.pb.gz') B = weights('B.pb.gz') output = A * 0.5 + B * 0.5 EOF # Run script uv run lc0-weights interpolate.py --output result.pb.gz ``` **Using stdin:** ```bash echo "output = weights('A.pb.gz') * 0.5 + weights('B.pb.gz') * 0.5" | \ uv run lc0-weights --output result.pb.gz ``` #### Using Input Aliases Pre-load networks to simplify expressions: ```bash uv run lc0-weights \ --input A=network_a.pb.gz \ --input B=network_b.pb.gz \ --expr "output = A * 0.9 + B * 0.1" \ --output result.pb.gz ``` #### Grafting a Policy Head Replace the policy head of one network with another: **Using inline expression:** ```bash uv run lc0-weights --expr " base = weights('base_network.pb.gz') donor = weights('network_with_better_policy.pb.gz') base.weights.policy = donor.weights.policy base.save('grafted.pb.gz') " ``` **Using a script file (recommended for multi-line operations):** ```bash # Create script cat > graft_policy.py << 'EOF' base = weights('base_network.pb.gz') donor = weights('network_with_better_policy.pb.gz') base.weights.policy = donor.weights.policy base.save('grafted.pb.gz') EOF # Run script uv run lc0-weights graft_policy.py ``` #### Format Conversion Convert a network to BFLOAT16 encoding: ```bash uv run lc0-weights \ --input net=network.pb.gz \ --expr "output = net" \ --output converted.pb.gz \ --encoding BFLOAT16 ``` #### Complex Expression Weighted average with custom formula: ```bash uv 
run lc0-weights \ --input A=net1.pb.gz \ --input B=net2.pb.gz \ --input C=net3.pb.gz \ --expr "output = A * 0.5 + B * 0.3 + C * 0.2" \ --output averaged.pb.gz ``` ## Python Library Usage ### Importing ```python from lczero_training.tools import load_weights, save_weights ``` ### Loading and Saving Weights ```python # Load a network net = load_weights("network.pb.gz") # Save with different encoding save_weights(net, "output.pb.gz", encoding="FLOAT16") ``` ### Arithmetic Operations ```python # Load networks net_a = load_weights("network_a.pb.gz") net_b = load_weights("network_b.pb.gz") # Interpolation (model averaging) interpolated = net_a * 0.7 + net_b * 0.3 # Addition combined = net_a + net_b # Subtraction difference = net_a - net_b # Scalar multiplication scaled = net_a * 0.5 # Save result save_weights(interpolated, "result.pb.gz") ``` ### Accessing and Modifying Components ```python net = load_weights("network.pb.gz") # Access nested weight arrays q_weights = net.weights.encoder[0].mha.q_w.value # Returns NumPy array print(q_weights.shape) # Modify weights net.weights.encoder[0].mha.q_w.value = q_weights * 1.1 # Save modified network save_weights(net, "modified.pb.gz") ``` ### Grafting Components ```python base = load_weights("base.pb.gz") donor = load_weights("donor.pb.gz") # Replace policy head base.weights.policy = donor.weights.policy # Replace value head base.weights.value_heads = donor.weights.value_heads # Replace specific encoder layer base.weights.encoder[0] = donor.weights.encoder[0] save_weights(base, "grafted.pb.gz") ``` ## Weight Encoding Formats The tool supports three encoding formats for weight storage: | Format | Description | Precision | File Size | | -------- | -------------------------------------------------------------------------------- | --------- | --------- | | LINEAR16 | Quantized 16-bit integer with min/max range. Default for Lc0. | ~4 digits | Smallest | | FLOAT16 | Native IEEE 754 half-precision floating point. 
| ~3 digits | Medium | | BFLOAT16 | Brain float 16 (truncated float32). Better range than FLOAT16, less precision. | ~2 digits | Medium | **Notes:** * LINEAR16 provides good compression with acceptable precision for neural networks * FLOAT16 is the default for this tool (good balance of precision and size) * BFLOAT16 is useful when training range matters more than mantissa precision * All formats are converted to float32 when loaded for arithmetic operations ## Common Use Cases ### Network Interpolation (Model Averaging) Combine two networks to create a smoother model or blend different training runs: ```python from lczero_training.tools import load_weights, save_weights net1 = load_weights("run1_final.pb.gz") net2 = load_weights("run2_final.pb.gz") # Average the networks averaged = net1 * 0.5 + net2 * 0.5 save_weights(averaged, "averaged_network.pb.gz") ``` ### Exponential Moving Average (EMA) Update a running average network with a new checkpoint: ```python ema_net = load_weights("ema.pb.gz") new_net = load_weights("latest_checkpoint.pb.gz") # EMA with decay 0.999 ema_updated = ema_net * 0.999 + new_net * 0.001 save_weights(ema_updated, "ema.pb.gz") ``` ### Policy Head Replacement Replace a network's policy head (useful for policy distillation): ```python student = load_weights("student_network.pb.gz") teacher = load_weights("teacher_network.pb.gz") # Replace student's policy head with teacher's student.weights.policy_heads = teacher.weights.policy_heads save_weights(student, "student_with_teacher_policy.pb.gz") ``` ### Extracting Network Statistics ```python net = load_weights("network.pb.gz") # Get statistics from first encoder layer layer = net.weights.encoder[0].mha.q_w.value print(f"Shape: {layer.shape}") print(f"Mean: {layer.mean():.6f}") print(f"Std: {layer.std():.6f}") print(f"Min: {layer.min():.6f}") print(f"Max: {layer.max():.6f}") ``` ### Format Conversion for Size Optimization ```python from lczero_training.tools import load_weights, save_weights # 
Load network (any format) net = load_weights("large_network.pb.gz") # Save with more aggressive compression save_weights(net, "compressed_network.pb.gz", encoding="LINEAR16") ``` ## Advanced Usage ### Accessing Nested Structures The weight wrapper provides Pythonic access to the nested protobuf structure: ```python net = load_weights("network.pb.gz") # Access input embedding weights input_weights = net.weights.input.weights.value # Access specific encoder layer encoder_layer_5 = net.weights.encoder[5] # Access multi-head attention components q_weights = net.weights.encoder[0].mha.q_w.value k_weights = net.weights.encoder[0].mha.k_w.value v_weights = net.weights.encoder[0].mha.v_w.value # Access policy head policy_weights = net.weights.policy_heads.vanilla.ip_pol_w.value ``` ### Complex Weighted Combinations ```python nets = [load_weights(f"checkpoint_{i}.pb.gz") for i in range(5)] weights = [0.1, 0.15, 0.2, 0.25, 0.3] # More weight to recent checkpoints result = sum(net * w for net, w in zip(nets, weights)) save_weights(result, "weighted_ensemble.pb.gz") ``` ### Selective Component Grafting ```python base = load_weights("base.pb.gz") donor = load_weights("donor.pb.gz") # Replace only the first 10 encoder layers for i in range(10): base.weights.encoder[i] = donor.weights.encoder[i] save_weights(base, "partial_graft.pb.gz") ``` ### Working with NumPy Arrays Directly All weight access returns NumPy arrays, allowing arbitrary transformations: ```python net = load_weights("network.pb.gz") # Get layer weights layer_weights = net.weights.encoder[0].mha.q_w.value # Apply custom transformation import numpy as np layer_weights_normalized = layer_weights / np.linalg.norm(layer_weights, axis=-1, keepdims=True) # Write back net.weights.encoder[0].mha.q_w.value = layer_weights_normalized save_weights(net, "normalized.pb.gz") ``` ## Implementation Details ### Lazy Loading Weights are decoded from their compressed format only when accessed, and cached for subsequent use. 
This makes the tool memory-efficient when working with large networks. ### File Format Support Both `.pb` (uncompressed protobuf) and `.pb.gz` (gzip-compressed) formats are supported. The tool automatically detects the format based on the file extension. ### Arithmetic Semantics * Operations are element-wise across all matching layers * Networks must have compatible structures (same number of encoder layers, etc.) * Results are computed in float32 precision regardless of input encoding ================================================ FILE: init.sh ================================================ #!/usr/bin/env bash protoc --proto_path=libs/lc0 --python_out=tf proto/net.proto touch tf/proto/__init__.py ================================================ FILE: justfile ================================================ # List available commands default: @just --list # Check if all C++ files in csrc/ are formatted according to clang-format check-cpp: find csrc/ -name "*.cpp" -o -name "*.cc" -o -name "*.cxx" -o -name "*.h" -o -name "*.hpp" | xargs clang-format --dry-run --Werror # Format all C++ files in csrc/ using clang-format format-cpp: find csrc/ -name "*.cpp" -o -name "*.cc" -o -name "*.cxx" -o -name "*.h" -o -name "*.hpp" | xargs clang-format -i # Check if all protobuf files are formatted according to clang-format check-proto: find proto/ -name "*.proto" | xargs clang-format --dry-run --Werror # Format all protobuf files using clang-format format-proto: find proto/ -name "*.proto" | xargs clang-format -i # Build Python protobuf files build-proto: mkdir -p src/proto touch src/proto/__init__.py uv run python -m grpc_tools.protoc \ --proto_path=. \ --proto_path=libs/lc0 \ --python_out=src/ \ --pyi_out=src/ \ proto/*.proto uv run python -m grpc_tools.protoc \ --proto_path=. 
\ --proto_path=libs/lc0 \ --python_out=src/ \ --pyi_out=src/ \ proto/net.proto \ proto/onnx.proto \ proto/hlo.proto # Check if all Python files in src/ are formatted according to ruff check-python: uv run ruff check src/ uv run ruff check --select I src/ uv run ruff format --check src/ uv run mypy -p lczero_training --disallow-untyped-defs --disallow-incomplete-defs # Format all Python files in src/ using ruff format-python: uv run ruff check --fix --select I src/ uv run ruff format src/ uv run ruff check --fix src/ format: format-cpp format-proto format-python # Setup meson build directory with clang setup-build: CXX=clang++ CC=clang uv run meson setup build/release \ --buildtype=release --native-file=native.ini # Build the project build: uv run meson compile -C build/release/ # Create symlink for the built extension module mksymlink: cd src/lczero_training && ln -sfT ../../build/release/_lczero_training.cpython-*-x86_64-linux-gnu.so _lczero_training.so # Delete build/release, re-setup, rebuild, and re-link rebuild: && setup-build build mksymlink rm -rf build/release # Run tests test-cpp: uv run meson test -C build/release/ test-python: uv run pytest test: test-cpp test-python check: check-cpp check-proto check-python # Run all checks (formatting, build, and tests) pre-commit: build-proto check build test ================================================ FILE: meson.build ================================================ project( 'lczero-training', 'cpp', version : '0.1', meson_version : '>= 1.3.0', default_options : [ 'warning_level=3', 'cpp_std=c++20', ], ) # Allow Clang nullability extensions when using Clang compiler cpp_compiler = meson.get_compiler('cpp') if cpp_compiler.get_id() == 'clang' add_project_arguments('-Wno-nullability-extension', language : 'cpp') endif # External dependencies zlib_dep = dependency('zlib') # Python and PyBind11 dependencies for Python extension python3 = import('python').find_installation() pybind11_dep = dependency('pybind11') # 
Abseil dependencies always resolved from the wrap subproject. absl_proj = subproject('abseil-cpp') absl_dep_specs = [ ['log', 'absl_log_dep'], ['log_initialize', 'absl_log_dep'], ['check', 'absl_log_dep'], ['hash', 'absl_hash_dep'], ['raw_hash_set', 'absl_container_dep'], ['synchronization', 'absl_synchronization_dep'], ['random_random', 'absl_random_dep'], ['flags', 'absl_flags_dep'], ['flags_parse', 'absl_flags_dep'], ['throw_delegate', 'absl_base_dep'], ] absl_deps = {} foreach spec : absl_dep_specs absl_deps += { spec[0].underscorify() : absl_proj.get_variable(spec[1]).as_system() } endforeach # Test dependencies gtest_dep = dependency('gtest').as_system() gtest_main_dep = dependency('gtest_main').as_system() # Gaviota tablebase subproject gaviota_dep = subproject('gaviotatb').get_variable('gaviotatb_dep') # Common dependency sets external_deps = [zlib_dep] core_absl_deps = [ absl_deps['log'], absl_deps['check'], absl_deps['hash'], absl_deps['raw_hash_set'], absl_deps['synchronization'], absl_deps['random_random'], ] loader_deps = external_deps + core_absl_deps + [gaviota_dep] main_deps = loader_deps + [absl_deps['log_initialize']] # test_deps will be defined after proto_dep cli_deps = [ absl_deps['log'], absl_deps['log_initialize'], absl_deps['flags'], absl_deps['flags_parse'], ] # Protobuf compilation setup compile_proto = find_program('libs/lc0/scripts/compile_proto.py') proto_gen = generator(compile_proto, output: ['@BASENAME@.pb.h'], arguments : [ '--proto_path=@CURRENT_SOURCE_DIR@', '--cpp_out=@BUILD_DIR@', '@INPUT@']) lc0_proto_gen = generator(compile_proto, output: ['@BASENAME@.pb.h'], arguments : [ '--proto_path=' + join_paths(meson.current_source_dir(), 'libs/lc0'), '--cpp_out=@BUILD_DIR@', '@INPUT@']) includes = include_directories('csrc', 'libs/lc0/src') add_project_arguments('-DNO_PEXT', language : 'cpp') rescorer_files = [ 'libs/lc0/src/chess/board.cc', 'libs/lc0/src/chess/gamestate.cc', 'libs/lc0/src/chess/position.cc', 
'libs/lc0/src/neural/decoder.cc',
  'libs/lc0/src/neural/encoder.cc',
  'libs/lc0/src/trainingdata/reader.cc',
  'libs/lc0/src/trainingdata/trainingdata.cc',
  'libs/lc0/src/trainingdata/writer.cc',
  'libs/lc0/src/utils/commandline.cc',
  'libs/lc0/src/utils/configfile.cc',
  'libs/lc0/src/utils/esc_codes.cc',
  'libs/lc0/src/utils/optionsdict.cc',
  'libs/lc0/src/utils/optionsparser.cc',
  'libs/lc0/src/utils/random.cc',
  'libs/lc0/src/utils/string.cc',
]
if host_machine.system() == 'windows'
  rescorer_files += 'libs/lc0/src/utils/filesystem.win32.cc'
else
  rescorer_files += 'libs/lc0/src/utils/filesystem.posix.cc'
endif
files = [
  'csrc/loader/chunk_source/debug_chunk_source.cc',
  'csrc/loader/chunk_source/rawfile_chunk_source.cc',
  'csrc/loader/chunk_source/tar_chunk_source.cc',
  'csrc/loader/data_loader_metrics.cc',
  'csrc/loader/data_loader.cc',
  'csrc/loader/stages/chunk_rescorer.cc',
  'csrc/loader/stages/chunk_source_loader.cc',
  'csrc/loader/stages/chunk_source_splitter.cc',
  'csrc/loader/stages/chunk_unpacker.cc',
  'csrc/loader/stages/file_path_provider.cc',
  'csrc/loader/stages/join_stage.cc',
  'csrc/loader/stages/position_sampling.cc',
  'csrc/loader/stages/shuffling_chunk_pool.cc',
  'csrc/loader/stages/shuffling_frame_sampler.cc',
  'csrc/loader/stages/simple_chunk_extractor.cc',
  'csrc/loader/stages/stage_factory.cc',
  'csrc/loader/stages/stage.cc',
  'csrc/loader/stages/tensor_generator.cc',
  'csrc/utils/gz.cc',
  'csrc/utils/stream_shuffler.cc',
  'csrc/utils/training_data_printer.cc',
  'libs/lc0/src/syzygy/syzygy.cc',
  'libs/lc0/src/trainingdata/rescorer.cc',
  'libs/lc0/src/utils/files.cc',
  'libs/lc0/src/utils/logging.cc',
  'libs/lc0/src/utils/protomessage.cc',
] + rescorer_files
# Process protobuf files for C++
proto_files = [
  proto_gen.process('proto/data_loader_config.proto', preserve_path_from : meson.current_source_dir()),
  proto_gen.process('proto/training_metrics.proto', preserve_path_from : meson.current_source_dir()),
proto_gen.process('proto/stage_control.proto', preserve_path_from : meson.current_source_dir()), lc0_proto_gen.process('libs/lc0/proto/net.proto', preserve_path_from : join_paths(meson.current_source_dir(), 'libs/lc0')), lc0_proto_gen.process('libs/lc0/proto/onnx.proto', preserve_path_from : join_paths(meson.current_source_dir(), 'libs/lc0')), lc0_proto_gen.process('libs/lc0/proto/hlo.proto', preserve_path_from : join_paths(meson.current_source_dir(), 'libs/lc0')) ] # Create a dependency for protobuf files that tests can use proto_dep = declare_dependency(sources: proto_files) test_deps = [gtest_dep, gtest_main_dep, proto_dep] files += proto_files loader_lib = static_library( 'loader', files, include_directories : includes, dependencies : loader_deps, ) exe = executable( 'loader', 'csrc/loader/loader_main.cpp', include_directories : includes, dependencies : cli_deps + [proto_dep], link_with : loader_lib, ) stream_shuffler_test = executable( 'stream_shuffler_test', 'csrc/utils/stream_shuffler_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['random_random']], link_with : loader_lib, ) queue_test = executable( 'queue_test', 'csrc/utils/queue_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']], ) file_path_provider_test = executable( 'file_path_provider_test', 'csrc/loader/stages/file_path_provider_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']], link_with : loader_lib, ) chunk_source_loader_test = executable( 'chunk_source_loader_test', 'csrc/loader/stages/chunk_source_loader_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization']], link_with : loader_lib, ) shuffling_chunk_pool_test = executable( 'shuffling_chunk_pool_test', 'csrc/loader/stages/shuffling_chunk_pool_test.cc', include_directories : includes, dependencies : test_deps + 
[absl_deps['synchronization'], absl_deps['log']], link_with : loader_lib, ) # simple_chunk_extractor_test = executable( # 'simple_chunk_extractor_test', # 'csrc/loader/stages/simple_chunk_extractor_test.cc', # include_directories : includes, # dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']], # link_with : loader_lib, # ) chunk_rescorer_test = executable( 'chunk_rescorer_test', 'csrc/loader/stages/chunk_rescorer_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']], link_with : loader_lib, ) chunk_unpacker_test = executable( 'chunk_unpacker_test', 'csrc/loader/stages/chunk_unpacker_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']], link_with : loader_lib, ) shuffling_frame_sampler_test = executable( 'shuffling_frame_sampler_test', 'csrc/loader/stages/shuffling_frame_sampler_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization'], absl_deps['random_random']], link_with : loader_lib, ) tensor_test = executable( 'tensor_test', 'csrc/utils/tensor_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['throw_delegate']], ) tensor_generator_test = executable( 'tensor_generator_test', 'csrc/loader/stages/tensor_generator_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization']], link_with : loader_lib, ) stats_test = executable( 'stats_test', 'csrc/utils/metrics/stats_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization']], link_with : loader_lib, ) load_metric_test = executable( 'load_metric_test', 'csrc/utils/metrics/load_metric_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization']], link_with : loader_lib, ) stage_factory_test = executable( 'stage_factory_test', 'csrc/loader/stages/stage_factory_test.cc', 
include_directories : includes, dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']], link_with : loader_lib, ) data_loader_test = executable( 'data_loader_test', 'csrc/loader/data_loader_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']], link_with : loader_lib, ) join_stage_test = executable( 'join_stage_test', 'csrc/loader/stages/join_stage_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']], link_with : loader_lib, ) test('stream_shuffler_test', stream_shuffler_test) test('queue_test', queue_test) test('file_path_provider_test', file_path_provider_test) test('chunk_source_loader_test', chunk_source_loader_test) chunk_source_splitter_test = executable( 'chunk_source_splitter_test', 'csrc/loader/stages/chunk_source_splitter_test.cc', include_directories : includes, dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']], link_with : loader_lib, ) test('chunk_source_splitter_test', chunk_source_splitter_test) test('shuffling_chunk_pool_test', shuffling_chunk_pool_test) # test('simple_chunk_extractor_test', simple_chunk_extractor_test) test('chunk_rescorer_test', chunk_rescorer_test) test('chunk_unpacker_test', chunk_unpacker_test) test('shuffling_frame_sampler_test', shuffling_frame_sampler_test) test('tensor_test', tensor_test) test('tensor_generator_test', tensor_generator_test) test('stats_test', stats_test) test('load_metric_test', load_metric_test) test('stage_factory_test', stage_factory_test) test('data_loader_test', data_loader_test) test('join_stage_test', join_stage_test) file_path_provider_main = executable( 'file_path_provider_main', 'csrc/loader/stages/file_path_provider_main.cc', include_directories : includes, dependencies : cli_deps + [proto_dep], link_with : loader_lib, ) dump_chunk = executable( 'dump_chunk', 'csrc/tools/dump_chunk_main.cc', include_directories : includes, 
dependencies : cli_deps + [zlib_dep], link_with : loader_lib, ) filter_chunks = executable( 'filter_chunks', 'csrc/tools/filter_chunks_main.cc', include_directories : includes, dependencies : cli_deps + [proto_dep, zlib_dep], link_with : loader_lib, ) result_distribution = executable( 'result_distribution', 'csrc/tools/result_distribution_main.cc', include_directories : includes, dependencies : cli_deps + [proto_dep, absl_deps['synchronization']], link_with : loader_lib, ) startpos_policy_distribution = executable( 'startpos_policy_distribution', 'csrc/tools/startpos_policy_distribution_main.cc', include_directories : includes, dependencies : cli_deps + [proto_dep], link_with : loader_lib, ) rescore_chunk = executable( 'rescore_chunk', 'csrc/tools/rescore_chunk_main.cc', include_directories : includes, dependencies : cli_deps + [proto_dep], link_with : loader_lib, ) position_weight_stats = executable( 'position_weight_stats', 'csrc/tools/position_weight_stats_main.cc', include_directories : includes, dependencies : cli_deps + [proto_dep], link_with : loader_lib, ) # Python extension module python3.extension_module( '_lczero_training', 'csrc/loader/pybind_module.cc', include_directories : includes, dependencies : [pybind11_dep, proto_dep, absl_deps['log_initialize']] + loader_deps, link_with : loader_lib, install : true, subdir : 'lczero_training', ) ================================================ FILE: native.ini ================================================ [binaries] python = '@GLOBAL_SOURCE_ROOT@/.venv/bin/python' ================================================ FILE: proto/checkpoint_migration_config.proto ================================================ syntax = "proto3"; package lczero_training.proto; message CheckpointMigrationRule { // Path in the old state pytree. string from_path = 1; // Path in the new state pytree. 
string to_path = 2; } message CheckpointMigrationConfig { repeated CheckpointMigrationRule rule = 1; } ================================================ FILE: proto/data_loader_config.proto ================================================ syntax = "proto2"; package lczero.training; // Configuration for output queue used by stages. message QueueConfig { // Queue overflow behavior. enum OverflowBehavior { BLOCK = 0; DROP_NEW = 1; KEEP_NEWEST = 2; } // Optional name for the output (used for multi-output stages). optional string name = 1; // Size of the output queue. optional uint64 queue_capacity = 2 [default = 4]; // Overflow behavior of the output queue. optional OverflowBehavior overflow_behavior = 3; } // Configuration for file path provider that watches directories for new // training files. Maps to FilePathProviderOptions in // csrc/loader/chunk_feed/file_path_provider.h message FilePathProviderConfig { // Path to directory containing training data files. optional string directory = 1; // Output queue configuration. optional QueueConfig output = 2; } // Configuration for chunk source loader that converts file paths to chunk // sources. Maps to ChunkSourceLoaderOptions in // csrc/loader/chunk_feed/chunk_source_loader.h message ChunkSourceLoaderConfig { enum FrameFormat { V6TrainingData = 0; V7TrainingData = 1; } // Number of worker threads for loading. optional uint64 threads = 1 [default = 1]; // Output queue configuration. optional QueueConfig output = 2; // Training data frame format. 
optional FrameFormat frame_format = 3; } message PositionSamplingConfig { optional float diff_focus_q_weight = 1 [default = 0.0]; optional float diff_focus_pol_scale = 2 [default = 1.0]; optional float diff_focus_alpha = 3 [default = 1.0]; optional float diff_focus_beta = 4 [default = 0.0]; optional float diff_focus_gamma = 5 [default = 1.0]; optional float diff_focus_tau = 6 [default = 1.0]; optional float default_weight = 7 [default = 1.0]; } // Configuration for shuffling chunk pool that manages chunk shuffling and // loading. Maps to ShufflingChunkPoolOptions in // csrc/loader/chunk_feed/shuffling_chunk_pool.h message ShufflingChunkPoolConfig { // Size of the chunk shuffle buffer. optional uint64 chunk_pool_size = 1 [default = 100000]; // Threads for ingesting new chunk sources (lightweight). optional uint64 source_ingestion_threads = 3 [default = 1]; // Threads for loading chunk data. optional uint64 chunk_loading_threads = 4 [default = 4]; // Output queue configuration. optional QueueConfig output = 5; // When set to a positive value, enable Hanse single-position sampling. // Probability to accept a chunk is // min(1.0, num_records / hanse_sampling_threshold) ** hanse_sampling_gamma // where num_records is the number of frames in the chunk. optional uint64 hanse_sampling_threshold = 6; // by default, disabled. optional double hanse_sampling_gamma = 7 [default = 1.0]; optional PositionSamplingConfig position_sampling = 11; // Optional output queue for cache hit frames. optional QueueConfig cachehit_output = 8; // Size of the position cache per chunk. optional uint64 position_cache_size = 9; // Threads for caching positions. optional uint64 caching_threads = 10 [default = 1]; } // Configuration for chunk rescorer that adjusts chunk metadata using // tablebases. Maps to ChunkRescorerOptions in // csrc/loader/stages/chunk_rescorer.h message ChunkRescorerConfig { // Number of worker threads for rescoring. 
optional uint64 threads = 1 [default = 1]; // Output queue configuration. optional QueueConfig output = 2; // Comma-separated list of Syzygy tablebase directories. optional string syzygy_paths = 3; // Policy reshaping temperature. optional float dist_temp = 4 [default = 1.0]; // Policy offset applied during rescoring. optional float dist_offset = 5 [default = 0.0]; // DTZ boost factor when policy adjustments apply. optional float dtz_boost = 6 [default = 0.0]; // Optional conversion target for input format (-1 keeps original). optional int32 new_input_format = 7 [default = -1]; // Optional deblunder threshold for policy adjustments. optional float deblunder_threshold = 8; // Optional deblunder width controlling smoothing around the threshold. optional float deblunder_width = 9; // Comma-separated list of Gaviota tablebase directories. optional string gaviota_paths = 11; // Soft threshold for st_q when converting v6 to v7. optional float st_q_theta = 12 [default = 0.8333333333]; } // Configuration for chunk unpacker that extracts frames from packed chunks. // Maps to ChunkUnpackerOptions in csrc/loader/chunk_feed/chunk_unpacker.h message ChunkUnpackerConfig { // Number of worker threads for unpacking. optional uint64 threads = 1 [default = 1]; // Probability of sampling each position within a chunk. optional float position_sampling_rate = 2 [default = 1.0]; // Number of positions to take from each chunk. optional uint32 position_count = 3; // Output queue configuration. optional QueueConfig output = 4; // Number of positions to prefetch for caching. optional uint32 prefetch_count = 5; // Optional output queue for prefetch cache requests. optional QueueConfig prefetch_output = 6; optional PositionSamplingConfig position_sampling = 7; } // Configuration for shuffling frame sampler using reservoir sampling. // Maps to ShufflingFrameSamplerOptions in csrc/loader/shuffling_frame_sampler.h message ShufflingFrameSamplerConfig { // Number of worker threads. 
optional uint64 threads = 1 [default = 1]; // Size of sampling reservoir per thread. optional uint64 reservoir_size_per_thread = 2 [default = 1000000]; // Output queue configuration. optional QueueConfig output = 3; } // Configuration for tensor generator that converts frames to batched tensors. // Maps to TensorGeneratorOptions in csrc/loader/tensor_generator.h message TensorGeneratorConfig { // Number of worker threads for tensor generation. optional uint64 threads = 1 [default = 1]; // Batch size for tensor generation. optional uint64 batch_size = 2 [default = 1024]; // Output queue configuration. optional QueueConfig output = 3; } // Configuration for a stage that splits incoming ChunkSources into several // outputs based on a deterministic hash of (chunk source name, index within // source). For each configured output, provides a name, weight that controls // the proportion of indices assigned to it, and queue parameters. message ChunkSourceSplitterConfig { // List of output queue configurations. repeated QueueConfig output = 1; // Relative weights for hash-based assignment (parallel to output). repeated uint64 weight = 2; } // Configuration for simple chunk shuffler that processes one chunk source at a // time and outputs all chunks in shuffled order. message SimpleChunkExtractorConfig { // Output queue configuration. optional QueueConfig output = 1; } // Configuration for join stage that merges multiple input streams. message JoinPositionsConfig { // Output queue configuration. optional QueueConfig output = 1; } // Stage-level configuration providing a name and stage-specific options. message StageConfig { // Unique name used to reference the stage output. optional string name = 1; // Names of upstream stages providing input to this stage. repeated string input = 2; // Field 3 reserved for future output configuration. 
optional FilePathProviderConfig file_path_provider = 4; optional ChunkSourceLoaderConfig chunk_source_loader = 5; optional ShufflingChunkPoolConfig shuffling_chunk_pool = 6; optional ChunkUnpackerConfig chunk_unpacker = 7; optional ShufflingFrameSamplerConfig shuffling_frame_sampler = 8; optional TensorGeneratorConfig tensor_generator = 9; optional ChunkRescorerConfig chunk_rescorer = 10; optional ChunkSourceSplitterConfig chunk_source_splitter = 11; optional SimpleChunkExtractorConfig simple_chunk_extractor = 12; optional JoinPositionsConfig join_positions = 13; } // Main configuration class for the DataLoader containing all component // configurations. message DataLoaderConfig { // Ordered list of stage configurations comprising the pipeline. repeated StageConfig stage = 1; // Exposed outputs, each with an optional alias used by clients to fetch // data. Expected format: "alias:stage.output", where "alias:" and ".output" // are optional. repeated string output = 2; } ================================================ FILE: proto/export_config.proto ================================================ syntax = "proto2"; package lczero.training; // Configuration for model export settings. message ExportConfig { // Destination filenames where exported models will be saved. // Supports formatting with {datetime} (YYYYMMDD-HHMMSS) and {step} (int). // Example: "/path/to/models/lc0-{datetime}-{step:08d}.pb.gz" repeated string destination_filename = 1; // Training run ID for uploading to training website. Only uploads when set. optional int32 upload_training_run = 2; // Whether to export the SWA model instead of the main model. optional bool export_swa_model = 3; } ================================================ FILE: proto/metrics_config.proto ================================================ syntax = "proto2"; package lczero.training; // Sentinel message for training batch sample type. message TrainingBatch {} // Configuration for weight-based metrics. 
message WeightsMetric { optional bool rms = 1; } // Configuration for individual metric collection. message MetricConfig { // Name of the metric. optional string name = 1; // Whether to use the SWA model for this metric. optional bool use_swa_model = 2; // Period for metric collection. optional int32 period = 3 [default = 1]; // Whether to use global steps instead of epoch-relative steps. optional bool use_global_steps = 4 [default = false]; // Whether to collect metric after epoch completion. optional bool after_epoch = 5; // Source of data for metric calculation. oneof sample { // Use training batch data. TrainingBatch training_batch = 6; // Name of dataloader output to use. string dataloader_output = 7; // Path to numpy tensor file (.npz). string npz_filename = 8; // Weight-based metrics. WeightsMetric weights = 9; } } // Configuration for metrics collection and export. message MetricsConfig { // Directory path where TensorBoard event files will be written. optional string tensorboard_path = 1; // List of metrics to collect during training. 
repeated MetricConfig metric = 2; } ================================================ FILE: proto/model_config.proto ================================================ syntax = "proto2"; import "proto/net.proto"; import "proto/hlo.proto"; package lczero.training; message ModelConfig { optional DefaultsConfig defaults = 4; optional EmbeddingConfig embedding = 5; optional EncoderConfig encoder = 6; optional uint32 shared_policy_embedding_size = 7; repeated PolicyHeadConfig policy_head = 8; repeated ValueHeadConfig value_head = 9; repeated MovesLeftHeadConfig movesleft_head = 10; } message DefaultsConfig { optional pblczero.XlaShapeProto.Type compute_dtype = 2; optional pblczero.NetworkFormat.ActivationFunction activation = 4; optional pblczero.NetworkFormat.ActivationFunction ffn_activation = 5; } message EmbeddingConfig { optional uint32 dense_size = 1; optional uint32 embedding_size = 2; optional uint32 dff = 3; } message EncoderConfig { optional uint32 num_blocks = 3; optional uint32 dff = 4; optional uint32 d_model = 5; optional uint32 heads = 6; optional SmolgenConfig smolgen = 9; } message SmolgenConfig { optional uint32 hidden_channels = 1; optional uint32 hidden_size = 2; optional uint32 gen_size = 3; optional pblczero.NetworkFormat.ActivationFunction activation = 5; } message PolicyHeadConfig { optional string name = 1; optional uint32 embedding_size = 2; optional uint32 d_model = 3; } message ValueHeadConfig { optional string name = 1; optional uint32 num_channels = 2; optional bool has_error_output = 3; optional uint32 num_categorical_buckets = 4; } message MovesLeftHeadConfig { optional string name = 1; optional uint32 num_channels = 2; } ================================================ FILE: proto/root_config.proto ================================================ syntax = "proto2"; package lczero.training; import "proto/data_loader_config.proto"; import "proto/model_config.proto"; import "proto/training_config.proto"; import "proto/metrics_config.proto"; 
import "proto/export_config.proto"; // Root configuration message containing all subsystem configurations. message RootConfig { // Name identifier for this configuration optional string name = 1; // Optional log filename for file-based logging optional string log_filename = 2; // Data loader configuration optional DataLoaderConfig data_loader = 3; // Training configuration optional TrainingConfig training = 4; // Model configuration optional ModelConfig model = 5; // Metrics configuration optional MetricsConfig metrics = 6; // Export configuration optional ExportConfig export = 7; } ================================================ FILE: proto/stage_control.proto ================================================ syntax = "proto2"; package lczero.training; message ShufflingChunkPoolControlRequest { optional bool reset_chunk_anchor = 1; optional string set_chunk_anchor = 2; } message StageControlRequest { optional ShufflingChunkPoolControlRequest chunk_pool_request = 1; } message ShufflingChunkPoolControlResponse { optional string chunk_anchor = 1; optional int32 chunks_since_anchor = 2; } message StageControlResponse { optional ShufflingChunkPoolControlResponse chunk_pool_response = 1; } ================================================ FILE: proto/training_config.proto ================================================ syntax = "proto3"; package lczero.training; // Configuration for training algorithm and parameters. message TrainingConfig { ScheduleConfig schedule = 1; repeated LrSchedule lr_schedule = 2; CheckpointConfig checkpoint = 3; OptimizerConfig optimizer = 4; LossConfig losses = 5; // Maximum gradient norm; set to 0 or omit to disable clipping. float max_grad_norm = 6; // Stochastic Weight Averaging (SWA) configuration. When absent, SWA is // disabled. 
SWAConfig swa = 7; } message ScheduleConfig { int32 steps_per_network = 1; int32 chunks_per_network = 2; } message OptimizerConfig { oneof optimizer_type { NadamwOptimizerConfig nadamw = 1; NadamOptimizerConfig nadam = 2; SgdOptimizerConfig sgd = 3; } // When set, weights selected by this selector are frozen (no gradient // updates). WeightsSelector freeze_selector = 4; } message LrSchedule { // Optimizer step when this schedule becomes active. int32 starting_step = 1; // Duration of each interval while this schedule is active. Last entry may be // zero to indicate an open-ended tail. repeated uint32 duration_steps = 2; // Learning rate at the beginning of each interval. repeated float lr = 3; enum Transition { CONSTANT = 0; LINEAR = 1; COSINE = 2; } // Transition type to use for each interval. Missing entries default to // CONSTANT. repeated Transition transition = 4; // When true this schedule loops after finishing the final interval. bool loop = 5; } message NadamwOptimizerConfig { float beta_1 = 1; float beta_2 = 2; float epsilon = 3; float weight_decay = 4; // Selector for weight decay. True = apply decay to this category. WeightsSelector decay_selector = 5; } message NadamOptimizerConfig { float beta_1 = 1; float beta_2 = 2; float epsilon = 3; } message SgdOptimizerConfig { float momentum = 1; bool nesterov = 2; } message WeightsSelectorRule { // Glob pattern matched against weight path. string match = 1; // True = include; false = exclude. bool include = 2; } // Selector for which weight categories to include. // Rules are evaluated in order; the first matching rule applies. // If no rule matches, otherwise_include is used. 
message WeightsSelector { repeated WeightsSelectorRule rule = 1; bool otherwise_include = 2; } message CheckpointConfig { string path = 1; int32 max_to_keep = 2; } message LossConfig { repeated PolicyLossConfig policy = 1; repeated ValueLossConfig value = 2; repeated MovesLeftLossConfig movesleft = 3; repeated ValueErrorLossConfig value_error = 4; repeated ValueCategoricalLossConfig value_categorical = 5; repeated RegularizationLossConfig regularization = 6; } message RegularizationLossConfig { enum RegularizationType { L2 = 0; } RegularizationType type = 1; string metric_name = 2; float weight = 3; // Selector for which weights to regularize. True = include in regularization. WeightsSelector selector = 4; } enum ValueType { RESULT = 0; BEST = 1; PLAYED = 2; ORIG = 3; ROOT = 4; ST = 5; } message OptimisticPolicyConfig { // Name of the value head to use for optimism computation. string value_head_name = 1; // Which value type to use from the training sample. ValueType value_type = 2; // Z-score threshold for optimism (e.g. 2.0). float strength = 3; // Epsilon for numerical stability (e.g. 1e-5). float eps = 4; // Sigmoid scaling factor (e.g. 3.0). float alpha = 5; // If false, prevent gradients from flowing to value and error heads. bool propagate_value_gradients = 6; } message PolicyLossConfig { string head_name = 1; string metric_name = 2; float weight = 3; enum IllegalMoveHandling { TRAIN_TO_ZERO = 0; MASK = 1; } IllegalMoveHandling illegal_moves = 4; enum LossType { LOSS_TYPE_UNSPECIFIED = 0; CROSS_ENTROPY = 1; KL = 2; } // Selects which policy loss implementation to use. Must be specified. LossType type = 5; // Soft policy temperature applied before normalizing targets for KL loss. // Values <= 0 disable the adjustment and keep raw targets. 
message ValueErrorLossConfig {
  // Name of the model head this loss is attached to.
  string head_name = 1;
  // Name under which the loss value is reported in metrics.
  string metric_name = 2;
  // Scalar multiplier applied to this loss term.
  float weight = 3;
  // Which value target from the training sample to compare against.
  ValueType value_type = 4;
  // If false, prevent gradients from flowing to the value head
  // (same flag as in OptimisticPolicyConfig).
  bool propagate_value_gradients = 5;
}
message QueueMetricProto {
  // Queue identifier.
  optional string name = 1;
  // Total number of items enqueued.
  optional uint64 put_count = 2 [default = 0];
  // Total number of items dequeued.
  optional uint64 get_count = 3 [default = 0];
  // Number of items dropped (presumably when the queue was full --
  // confirm against the queue implementation).
  optional uint64 drop_count = 4 [default = 0];
  // Distribution of observed queue sizes over the reporting window.
  optional StatisticsProtoInt64 queue_fullness = 5;
  // Maximum number of items the queue can hold.
  optional uint64 queue_capacity = 6 [default = 0];
}
"lczero_training.commands.describe_training:main" lc0-test-dataloader = "lczero_training.commands.test_dataloader:main" lc0-dataloader-viz = "lczero_training.commands.dataloader_viz:main" lc0-overfit = "lczero_training.commands.overfit:main" lc0-init = "lczero_training.commands.training_init:main" lc0-eval = "lczero_training.commands.training_eval:main" lc0-tune-lr = "lczero_training.commands.tune_lr:main" lc0-migrate-checkpoint = "lczero_training.commands.migrate_checkpoint:main" lc0-backfill-metrics = "lczero_training.commands.backfill_metrics:main" lc0-train = "lczero_training.commands.train:main" lc0-daemon = "lczero_training.commands.daemon:main" lc0-tui = "lczero_training.commands.tui:main" lc0-weights = "lczero_training.commands.weights_tool:main" [project.optional-dependencies] dev = [ "pytest>=7.0.0", "mypy>=1.0.0", "typing-extensions>=4.0.0", "mypy-extensions>=0.4.0", "pathspec>=0.11.0", ] [build-system] requires = ["setuptools>=64", "pybind11>=2.10.0", "wheel"] build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] where = ["src"] [tool.setuptools.package-data] "*" = ["*.so", "*.dll", "*.dylib"] [dependency-groups] dev = [ "grpcio-tools>=1.76.0", "mypy-protobuf>=3.6.0", "textual-dev>=1.7.0", "types-protobuf>=6.30.2.20250809", "types-requests>=2.32.4.20250913", ] [tool.mypy] mypy_path = "src" # TEMPORARY: Disable misc errors due to protobuf bug in PR #23156 # (https://github.com/protocolbuffers/protobuf/pull/23156) # Revert when protobuf fixes __slots__ = () in generated .pyi files disable_error_code = ["misc"] [[tool.mypy.overrides]] module = "orbax.*" ignore_missing_imports = true [[tool.mypy.overrides]] module = "optax" ignore_missing_imports = true [[tool.mypy.overrides]] module = "onnxruntime" ignore_missing_imports = true [[tool.mypy.overrides]] module = "tensorboardX" ignore_missing_imports = true [tool.pytest.ini_options] testpaths = ["src"] python_files = ["test_*.py", "*_test.py"] pythonpath = ["src"] addopts = "-v" [tool.ruff] 
#!/usr/bin/env python
import glob
import os
import argparse


def get_sorted_chunk_ids(dirs):
    """Collect training chunk ids found in *dirs*, largest id first."""
    ids = [
        int(os.path.basename(path).split('.')[-2])
        for directory in dirs
        for path in glob.glob(os.path.join(directory, "training.*.gz"))
    ]
    return sorted(ids, reverse=True)


def main(argv):
    """Print chunks in the newest input window that are absent from dirs."""
    input_ids = get_sorted_chunk_ids([argv.input])
    present = set(get_sorted_chunk_ids(argv.dirs))
    window = min(argv.wsize, len(input_ids))
    for chunk_id in sorted(set(input_ids[:window]) - present):
        print('training.{}.gz'.format(chunk_id))


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description='Print diffset of input dir and output dirs.')
    argparser.add_argument('-i', '--input', type=str, help='input directory')
    argparser.add_argument('-w', '--wsize', type=int, help='window size')
    argparser.add_argument('dirs', nargs='+', help='output directories')
    main(argparser.parse_args())
#!/usr/bin/env python
import glob
import os
import argparse


def get_sorted_chunk_ids(dirs):
    """Collect training chunk ids found in *dirs*, largest id first."""
    ids = [
        int(os.path.basename(path).split('.')[-2])
        for directory in dirs
        for path in glob.glob(os.path.join(directory, "training.*.gz"))
    ]
    return sorted(ids, reverse=True)


def main(argv):
    """Hard-link the newest window of chunks into test/train subdirectories.

    Chunk ids whose last two decimal digits are 90-99 go to 'test', all
    others to 'train' (a deterministic 10:90 split on the id).
    """
    ids = get_sorted_chunk_ids([argv.input])
    window = min(argv.wsize, len(ids))
    for chunk_id in sorted(ids[:window]):
        subdir = "test" if chunk_id % 100 >= 90 else "train"
        os.link(
            os.path.join(argv.input, "training.{}.gz".format(chunk_id)),
            os.path.join(argv.output, subdir,
                         "training.{}.gz".format(chunk_id)))


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description='Link input to test/train subdirectories of output '
        'in 10:90 ratio.')
    argparser.add_argument('-i', '--input', type=str, help='input directory')
    argparser.add_argument(
        '-w',
        '--wsize',
        type=int,
        help=
        'window size - should be padded a bit to ensure both sides of split exceed fraction of target'
    )
    argparser.add_argument('-o', '--output', type=str, help='output directory')
    main(argparser.parse_args())
#!/usr/bin/env python
import glob
import os
import argparse


def get_sorted_chunk_ids(dirs):
    """Return (id, path) pairs for game_*.gz files in *dirs*, sorted by id.

    The numeric suffix is parsed directly with int(), which accepts leading
    zeros (including the all-zero "game_000000.gz" -> 0). This replaces the
    previous special case plus lstrip("0"), which was redundant and would
    have crashed on int("") for all-zero names had the special case ever
    been removed.
    """
    ids = []
    for d in dirs:
        for f in glob.glob(os.path.join(d, "game_*.gz")):
            stem = os.path.basename(f).split('.')[-2]
            ids.append((int(stem.split('_')[-1]), f))
    ids.sort()
    return ids


def main(argv):
    """Rename lc0 game_NNNNNN.gz files to training.<id + base>.gz in place."""
    for chunk_id, path in get_sorted_chunk_ids([argv.input]):
        os.rename(
            path,
            os.path.join(argv.input,
                         "training.{}.gz".format(chunk_id + argv.base)))


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description='Rename files generated by lc0 to have training names.')
    argparser.add_argument('-i', '--input', type=str, help='input directory')
    argparser.add_argument('-b',
                           '--base',
                           type=int,
                           default=0,
                           help='base value for names')
    main(argparser.parse_args())
def main():
    """Group sorted chunk ids into packs of argv.number and bz2-pack them.

    Creates the output directory if needed, then fans the packs out to a
    process pool running pack(). Uses the module-level `argv` namespace.
    """
    if not os.path.exists(argv.output):
        os.makedirs(argv.output)
        print("Created directory '{}'".format(argv.output))
    ids = get_sorted_chunk_ids([argv.input])
    m = argv.number
    n = len(ids) // m
    print("Processing {} ids, {} - {} ({}x{})".format(len(ids), ids[0],
                                                      ids[-1], n, m))
    packs = [ids[i * m:(i + 1) * m] for i in range(n)]
    # BUG FIX: the leftover ids are ids[n * m:]. The previous code appended
    # ids[n * m + m:], which is always empty, so up to (m - 1) trailing
    # chunks were silently never packed.
    if packs:
        packs[-1] += ids[n * m:]
    else:
        # Fewer ids than one full pack: pack whatever we have instead of
        # raising IndexError on packs[-1].
        packs = [ids]
    with Pool() as pool:
        pool.map(pack, packs)
def shuffle(files):
    """Merge *files* into a single shuffled chunk file.

    Reads every position record from the given .gz chunk files, shuffles
    them in memory, validates the V4 record format, writes the result to
    '<first-file>_shuffled.gz' via a temp file, then removes the inputs.
    """
    data = []
    for filename in files:
        with gzip.open(filename, 'rb') as f:
            data.extend(positions(f.read()))
    # BUG FIX: the module-level flag `shuffle = True` is shadowed by this
    # function's own name, so the old `if shuffle:` tested the function
    # object and was always true. Shuffle unconditionally, which matches
    # the effective behavior the shadowing produced.
    random.shuffle(data)
    for record in data:
        # First byte of every record is the format version; 0x04 == V4.
        if record[0] != 0x04:
            print(files)
            raise ValueError('Wrong training data format, not V4')
    base, ext = os.path.splitext(files[0])
    new_file = base + '_shuffled' + ext
    new_file_temp = new_file + '.temp'
    with gzip.open(new_file_temp, 'wb', compresslevel=9) as f:
        for record in data:
            f.write(record)
    # For interrupt safety, make sure not to write partial chunks.
    os.rename(new_file_temp, new_file)
    for filename in files:
        os.remove(filename)
# Validate one arriving chunk file and hard-link it into the test or train
# split, trimming the oldest files once either side outgrows its budget.
#
# Arguments: $1 = source directory, $2 = chunk file name.
# Uses/updates globals: n, n_test, n_train, latest, plus the limits computed
# at startup (max_test, max_train, overhead_test, overhead_train).
# Progress is reported as single characters on stdout:
#   X = corrupt file skipped, T = linked to test, * = linked to train,
#   - = test side trimmed, _ = train side trimmed.
process() {
    local dir=$1
    local file=$2
    if [[ $file = training.*.gz ]]
    then
        # compute basic file integrity check
        # (empty, or not a whole number of records => writer was interrupted)
        size=$(zcat $dir/$file | wc -c)
        let rem="size % $RECORDSIZE"
        if [[ $size -eq 0 ]] || [[ $rem -ne 0 ]]
        then
            echo -n "X"
            return
        fi
        # new file, put hard link in correct directory
        let "n++"
        id=$(echo $file | cut -d'.' -f 2)
        # Deterministic split on the chunk id: hash_index in 1..100.
        # NOTE(review): initsplit.py routes ids with `id % 100 >= 90` to
        # test, while this uses `id % 100 + 1 > TRAINPCT`; the boundary ids
        # differ by one between the two scripts -- confirm which is intended.
        let hash_index="$id % 100 + 1"
        if [ $id -gt $latest ]
        then
            latest=$id
            if [ -f $LATESTFILE ]
            then
                echo $latest > $LATESTFILE
            fi
        fi
        if [ $hash_index -gt $TRAINPCT ]
        then
            let "n_test++"
            target=$TESTDIR/$file
            echo -n "T"
        else
            let "n_train++"
            target=$TRAINDIR/$file
            echo -n "*"
        fi
        ln $dir/$file $target
        # exceeding max buffer size for either, lock and remove overhead as appropriate
        if [ $n_test -gt $max_test ] || [ $n_train -gt $max_train ]
        then
            # Trim the oldest files under an exclusive flock so concurrent
            # users of $LC0LOCKFILE see a consistent window.
            (
                flock -e 200
                if [ $n_test -gt $max_test ]
                then
                    ls -rt $TESTDIR | head -n $overhead_test | xargs -I{} rm -f $TESTDIR/{}
                    echo -n "-"
                fi
                if [ $n_train -gt $max_train ]
                then
                    ls -rt $TRAINDIR | head -n $overhead_train | xargs -I{} rm -f $TRAINDIR/{}
                    echo -n "_"
                fi
            ) 200>$LC0LOCKFILE
            # Update this shell's counters (the subshell above cannot
            # propagate variable changes back).
            if [ $n_test -gt $max_test ]
            then
                let "n -= $overhead_test"
                let "n_test -= $overhead_test"
                echo -n "-"
            fi
            if [ $n_train -gt $max_train ]
            then
                let "n -= $overhead_train"
                let "n_train -= $overhead_train"
                echo -n "_"
            fi
        fi
    fi
}
def unpack(filepath):
    """Split a packed '<first>-<last>.bz2' bundle back into gz chunk files.

    Bundle layout (see pack.py): concatenated raw records, then an int16
    array with the per-chunk record counts, then a uint32 with the byte
    size of that array. Output files go to argv.output.
    """
    front, back = os.path.basename(filepath).split('-')
    first = int(front)
    last = int(back.split('.')[0])
    num_chunks = last - first + 1
    # Context manager closes the file handle (previously leaked).
    with bz2.BZ2File(filepath, 'rb') as bundle:
        buf = bundle.read()
    size = struct.unpack('I', buf[-4:])[0]
    plylist = np.frombuffer(buf[-4 - size:-4], dtype=np.int16)
    data = np.frombuffer(buf[:-4 - size],
                         dtype=np.int8).reshape(-1, RECORD_SIZE)
    # Explicit error instead of assert, which is stripped under python -O.
    if num_chunks != len(plylist):
        raise ValueError(
            "Chunk count mismatch in '{}': name implies {}, index has {}".
            format(filepath, num_chunks, len(plylist)))
    begin = 0
    for i, plies in enumerate(plylist):
        end = begin + plies
        filename = os.path.join(argv.output,
                                "training.{}.gz".format(i + first))
        with gzip.open(filename, 'wb') as f:
            # Write the chunk's records in one call (same bytes as the
            # previous row-by-row loop).
            f.write(data[begin:end].tobytes())
        begin = end
    print("Written {} chunks".format(num_chunks))
class TensorBase:
    """Stub for the C++ tensor type exposed by the compiled module.

    Accessor names mirror the Python buffer protocol's metadata
    (shape/strides/itemsize/format).
    """

    def shape(self) -> List[int]:
        """Return the size of each tensor dimension."""
        ...

    def strides(self) -> List[int]:
        """Return per-dimension strides (presumably in bytes -- confirm
        against csrc/utils/tensor.h)."""
        ...

    def element_size(self) -> int:
        """Return the size of a single element (presumably in bytes)."""
        ...

    def py_format(self) -> str:
        """Return the struct-style format string for the element type."""
        ...
def add_stages(self, config: DataLoaderConfig | bytes) -> None: ... def send_control_message( self, request: StageControlRequest | bytes ) -> List[Tuple[str, StageControlResponse]]: ... def start(self) -> None: ... def get_next(self, alias: str = "") -> Tuple[np.ndarray, ...]: ... def maybe_get_next( self, alias: str = "" ) -> Optional[Tuple[np.ndarray, ...]]: ... def stop(self) -> None: ... def get_bucket_metrics( self, time_period: int, include_pending: bool ) -> Tuple[bytes, float]: ... def get_aggregate_ending_now( self, duration_seconds: float, include_pending: bool ) -> Tuple[bytes, float]: ... ================================================ FILE: src/lczero_training/commands/__init__.py ================================================ """Command entrypoint scaffolding and shared CLI helpers. This package will host thin wrappers for individual tools (convert, training, daemon, tui) as they are extracted from nested module __main__ dispatchers in subsequent phases. Phase 1 provides common helpers for consistent logging and argument handling across commands without changing behaviour. """ from .common import ( add_logging_arguments, configure_root_logging, parse_log_level, ) __all__ = [ "configure_root_logging", "add_logging_arguments", "parse_log_level", ] ================================================ FILE: src/lczero_training/commands/backfill_metrics.py ================================================ import argparse import logging import sys from lczero_training.commands import configure_root_logging from lczero_training.training.backfill_metrics import backfill_metrics def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Backfill metrics for existing checkpoints." 
) parser.add_argument( "--config", type=str, required=True, help="Path to the RootConfig textproto config.", ) parser.add_argument( "--metrics", nargs="+", required=True, help="Names of metrics to backfill (must be NPZ metrics).", ) parser.add_argument( "--min-step", type=int, help="Minimum checkpoint step (inclusive) to process.", ) parser.add_argument( "--max-step", type=int, help="Maximum checkpoint step (inclusive) to process.", ) parser.add_argument( "--migration-config", help=( "Path to a CheckpointMigrationConfig textproto file. " "If provided, checkpoints will be migrated before evaluation." ), ) return parser def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) # Lazy import to avoid heavy deps unless executing the command. backfill_metrics( config_path=args.config, metric_names=args.metrics, min_step=args.min_step, max_step=args.max_step, migration_config_path=args.migration_config, ) return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/common.py ================================================ import argparse import logging import os import sys _DEFAULT_FORMAT = ( "%(levelname).1s%(asctime)s.%(msecs)03d %(name)s " "%(filename)s:%(lineno)d] %(message)s" ) _DEFAULT_DATEFMT = "%m%d %H:%M:%S" def configure_root_logging(level: int | str = logging.INFO) -> None: """Configure root logging with a consistent, terse format. - Matches existing project format used in module __main__ files. - Respects explicit level passed as int or name (e.g. "INFO"). - Uses stderr by default. - Forces reconfiguration to avoid duplicate handlers during nested runs. 
""" resolved_level = parse_log_level(level) logging.basicConfig( level=resolved_level, format=_DEFAULT_FORMAT, datefmt=_DEFAULT_DATEFMT, stream=sys.stderr, force=True, ) def parse_log_level(level: int | str) -> int: """Parse log level from int or string, with sane defaults. Accepts numeric levels or case-insensitive names like "DEBUG". Falls back to INFO on invalid input. """ if isinstance(level, int): return level if isinstance(level, str): name = level.strip().upper() return getattr(logging, name, logging.INFO) return logging.INFO def add_logging_arguments(parser: argparse.ArgumentParser) -> None: """Add common logging CLI arguments to a parser. Does not enable the flags by default; commands may opt-in and then call configure_root_logging(parse_log_level(args.log_level)) if present. """ parser.add_argument( "--log-level", default=os.environ.get("LCZERO_LOG_LEVEL", "INFO"), help="Logging level (DEBUG, INFO, WARNING, ERROR).", ) ================================================ FILE: src/lczero_training/commands/daemon.py ================================================ import argparse import logging import sys from lczero_training.commands import configure_root_logging from lczero_training.daemon.daemon import TrainingDaemon def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run the training daemon.") parser.add_argument("--memory-profile-dir", default=None) return parser def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) daemon = TrainingDaemon(memory_profile_dir=args.memory_profile_dir) daemon.run() return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/dataloader_viz.py ================================================ import argparse import sys from google.protobuf import text_format from graphviz import Digraph # type: ignore from 
def main(argv: list[str] | None = None) -> int:
    """Render the data loader stage graph from a RootConfig textproto.

    Each stage becomes a node labeled with its textproto definition; edges
    follow stage input references, and declared outputs get ellipse nodes.
    """
    configure_root_logging()
    parser = _build_parser()
    args = parser.parse_args(argv)

    config = RootConfig()
    with open(args.config, "r") as f:
        text_format.Parse(f.read(), config)

    dot = Digraph(comment="DataLoader Pipeline")
    dot.attr(rankdir="TB")
    dot.attr(
        "node",
        shape="box",
        fontname="monospace",
        fontsize="10",
        labeljust="l",
    )

    for stage in config.data_loader.stage:
        stage_text = text_format.MessageToString(stage, as_one_line=False)
        # BUG FIX: the label is a Graphviz HTML-like label (wrapped in <...>
        # below), so literal &, < and > in the textproto must be entity
        # escaped; the previous identity replaces broke rendering for any
        # stage containing those characters. Newlines become left-aligned
        # <br/> tags to match labeljust="l".
        br_tag = '<br align="left"/>'
        escaped_text = (
            stage_text.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace("\n", br_tag)
        )
        label = f"<{escaped_text}>"
        dot.node(stage.name, label=label, shape="box")
        # Inputs are "stage" or "stage.output" references.
        for input_spec in stage.input:
            parts = input_spec.split(".", 1)
            source_stage = parts[0]
            if len(parts) == 2:
                dot.edge(source_stage, stage.name, label=parts[1])
            else:
                dot.edge(source_stage, stage.name)

    # Outputs are "alias:source" or plain "source" specs.
    for output_spec in config.data_loader.output:
        parts = output_spec.split(":", 1)
        if len(parts) == 2:
            alias, source = parts
        else:
            alias = output_spec
            source = output_spec
        source_parts = source.split(".", 1)
        source_stage = source_parts[0]
        dot.node(
            f"output_{alias}",
            label=f"Output: {alias}",
            shape="ellipse",
            style="filled",
            fillcolor="lightblue",
        )
        if len(source_parts) == 2:
            dot.edge(source_stage, f"output_{alias}", label=source_parts[1])
        else:
            dot.edge(source_stage, f"output_{alias}")

    # Infer the render format from the extension, defaulting to svg.
    output_format = args.output.rsplit(".", 1)[-1].lower()
    if output_format not in ("svg", "png"):
        output_format = "svg"
    dot.render(
        outfile=args.output, format=output_format, cleanup=True, engine="dot"
    )
    return 0
parser.parse_args(argv) # Import on demand to avoid importing heavy deps on --help. from lczero_training.training.describe import describe describe( config_filename=args.config, shapes=args.shapes, values=args.values, weight_paths=args.weight_paths, ) return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/jax2leela.py ================================================ import argparse import gzip import logging import os import sys import orbax.checkpoint as ocp from google.protobuf import text_format from lczero_training.commands import configure_root_logging from lczero_training.convert.jax_to_leela import ( LeelaExportOptions, jax_to_leela, ) from lczero_training.training.state import TrainingState from proto.root_config_pb2 import RootConfig def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Export JAX checkpoint to Leela format." ) parser.add_argument( "--config", type=str, required=True, help="Path to the training config file.", ) parser.add_argument( "--output", type=str, required=True, help="Output path for the Leela network (.pb.gz).", ) parser.add_argument( "--export-swa", action="store_true", help="Export SWA model instead of regular model_state.", ) parser.add_argument( "--min-version", type=str, default="0.31", help="Minimum lc0 version for exported network (default: 0.31).", ) return parser def jax2leela( config_filename: str, output_path: str, export_swa: bool, min_version: str, ) -> None: config = RootConfig() logging.info("Reading configuration from %s", config_filename) with open(config_filename, "r") as f: text_format.Parse(f.read(), config) if not config.training.checkpoint.path: logging.error("Checkpoint path must be set in the configuration.") sys.exit(1) logging.info("Loading checkpoint from %s", config.training.checkpoint.path) checkpoint_mgr = ocp.CheckpointManager( config.training.checkpoint.path, 
options=ocp.CheckpointManagerOptions(create=True), ) empty_state = TrainingState.new_from_config( model_config=config.model, training_config=config.training, ) restored_state = checkpoint_mgr.restore( checkpoint_mgr.latest_step(), args=ocp.args.PyTreeRestore(empty_state) ) assert isinstance(restored_state, TrainingState) logging.info( "Restored checkpoint at step %d", restored_state.jit_state.step ) if export_swa: if restored_state.jit_state.swa_state is None: logging.error( "SWA export requested but SWA state is None in checkpoint." ) sys.exit(1) export_state = restored_state.jit_state.swa_state logging.info("Exporting SWA model") else: export_state = restored_state.jit_state.model_state logging.info("Exporting regular model") options = LeelaExportOptions( min_version=min_version, num_heads=restored_state.num_heads, license=None, training_steps=restored_state.jit_state.step, ) logging.info("Converting to Leela format") net = jax_to_leela(jax_weights=export_state, export_options=options) logging.info("Serializing network") network_bytes = gzip.compress(net.SerializeToString()) os.makedirs(os.path.dirname(output_path), exist_ok=True) logging.info("Writing network to %s", output_path) with open(output_path, "wb") as f: f.write(network_bytes) logging.info("Export complete") def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) jax2leela( config_filename=args.config, output_path=args.output, export_swa=args.export_swa, min_version=args.min_version, ) return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/leela2jax.py ================================================ import argparse import sys def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Convert Leela Zero weights to JAX format." 
) parser.add_argument( "input", type=str, help="Path to the input Lc0 weights file." ) parser.add_argument( "--output-model-config", type=str, help="Output path to the ModelConfig textproto.", ) parser.add_argument( "--weights-dtype", default="F32", type=str, help="The data type of the weights.", ) parser.add_argument( "--compute-dtype", default="F32", type=str, help="The data type for computation.", ) parser.add_argument( "--print-model-config", action="store_true", help="Print the ModelConfig textproto to stdout.", ) parser.add_argument( "--output-serialized-jax", type=str, help="Path to save the output JAX serialized state.", ) parser.add_argument( "--output-leela-verification", type=str, help=( "Path to save the round-trip converted Leela network (.pb.gz) for " "verification." ), ) return parser def main(argv: list[str] | None = None) -> int: parser = _build_parser() args = parser.parse_args(argv) # Import on demand to avoid importing heavy deps on --help. from lczero_training.convert.leela_to_jax import ( leela_to_jax_files, ) leela_to_jax_files( input_path=args.input, weights_dtype=args.weights_dtype, compute_dtype=args.compute_dtype, output_modelconfig=args.output_model_config, output_serialized_jax=args.output_serialized_jax, output_leela_verification=args.output_leela_verification, print_modelconfig=args.print_model_config, ) return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/migrate_checkpoint.py ================================================ import argparse import logging import sys from lczero_training.commands import configure_root_logging def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Migrate a checkpoint to a new training state." 
) parser.add_argument( "--config", type=str, required=True, help="Path to the RootConfig textproto config.", ) parser.add_argument( "--new_checkpoint", help=( "Path to save the new checkpoint to. If not set, the tool only " "checks whether the migration rules fully cover the differences." ), ) parser.add_argument( "--overwrite", action="store_true", help="If set, allows overwriting existing checkpoint.", ) parser.add_argument( "--rules_file", help=( "Path to a CheckpointMigrationConfig textproto file containing " "the migration rules." ), ) parser.add_argument( "--serialized-model", action="store_true", default=False, help="Use serialized state for a model.", ) parser.add_argument( "--checkpoint_step", type=int, help=( "If set, use this step when loading from old checkpoint instead " "of the latest." ), ) parser.add_argument( "--new_checkpoint_step", type=int, help=( "If set, use this step when saving the new checkpoint instead of " "copying the old step." ), ) parser.add_argument("--dump_source_paths", action="store_true") parser.add_argument("--dump_destination_paths", action="store_true") return parser def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) # Lazy import to avoid heavy deps unless executing the command. 
from lczero_training.training.migrate_checkpoint import ( migrate_checkpoint, ) migrate_checkpoint( config=args.config, new_checkpoint=args.new_checkpoint, overwrite=args.overwrite, rules_file=args.rules_file, serialized_model=args.serialized_model, checkpoint_step=args.checkpoint_step, new_checkpoint_step=args.new_checkpoint_step, dump_source_paths=args.dump_source_paths, dump_destination_paths=args.dump_destination_paths, ) return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/overfit.py ================================================ import argparse import logging import sys from lczero_training.commands import configure_root_logging def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Run an overfitting test on a single batch." ) parser.add_argument( "--config", type=str, required=True, help="Path to the training config file.", ) parser.add_argument( "--num-steps", type=int, required=True, help="Number of training steps to run on the fixed batch.", ) parser.add_argument( "--coin-flip", action="store_true", help=( "Train on two batches: first train batch A while evaluating batch B, then vice versa." ), ) parser.add_argument( "--csv-file", type=str, help="Optional path to write step-by-step overfit results.", ) return parser def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) # Import on demand to avoid importing heavy deps on --help. 
from lczero_training.training.overfit import overfit overfit( config_filename=args.config, num_steps=args.num_steps, coin_flip=args.coin_flip, csv_file=args.csv_file, ) return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/test_dataloader.py ================================================ import argparse import logging import sys from lczero_training.commands import configure_root_logging def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description=( "Fetch batches from the data loader to measure latency and throughput." ) ) parser.add_argument( "--config", type=str, required=True, help="Path to the training config file.", ) parser.add_argument( "--num-batches", type=int, default=10, help="Number of batches to fetch from the data loader.", ) parser.add_argument( "--npz-output", type=str, help="Optional path to store fetched batches as an .npz archive.", ) return parser def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) # Import on demand to avoid importing heavy deps on --help. 
from lczero_training.training.dataloader_probe import probe_dataloader probe_dataloader( config_filename=args.config, num_batches=args.num_batches, npz_output=args.npz_output, ) return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/train.py ================================================ import argparse import datetime import gzip import logging import os import sys import orbax.checkpoint as ocp from flax import nnx from google.protobuf import text_format from lczero_training.commands import configure_root_logging from lczero_training.convert.jax_to_leela import ( LeelaExportOptions, jax_to_leela, ) from lczero_training.dataloader import make_dataloader from lczero_training.model.loss_function import LczeroLoss from lczero_training.model.model import LczeroModel from lczero_training.training.lr_schedule import make_lr_schedule from lczero_training.training.optimizer import make_gradient_transformation from lczero_training.training.state import TrainingState from lczero_training.training.training import Training, from_dataloader from proto.root_config_pb2 import RootConfig def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Start a training run.") parser.add_argument( "--config", type=str, required=True, help="Path to the training config file.", ) return parser def train(config_filename: str) -> None: config = RootConfig() logging.info("Reading configuration from proto file") with open(config_filename, "r") as f: text_format.Parse(f.read(), config) if config.training.checkpoint.path is None: logging.error("Checkpoint path must be set in the configuration.") sys.exit(1) checkpoint_mgr = ocp.CheckpointManager( config.training.checkpoint.path, options=ocp.CheckpointManagerOptions( create=True, ), ) logging.info("Creating state from configuration") empty_state = TrainingState.new_from_config( model_config=config.model, 
training_config=config.training, ) logging.info("Restoring checkpoint") training_state = checkpoint_mgr.restore( None, args=ocp.args.PyTreeRestore(empty_state) ) logging.info("Restored checkpoint") model, _ = nnx.split( LczeroModel(config=config.model, rngs=nnx.Rngs(params=42)) ) assert isinstance(training_state, TrainingState) jit_state = training_state.jit_state lr_sched = make_lr_schedule(config.training.lr_schedule) optimizer_tx = make_gradient_transformation( config.training.optimizer, max_grad_norm=getattr(config.training, "max_grad_norm", 0.0), lr_schedule=lr_sched, ) training = Training( optimizer_tx=optimizer_tx, graphdef=model, loss_fn=LczeroLoss(config=config.training.losses), swa_config=( config.training.swa if config.training.HasField("swa") else None ), ) new_state = training.run( jit_state, from_dataloader(make_dataloader(config.data_loader)), config.training.schedule.steps_per_network, ) if config.export.destination_filename: date_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") logging.info("Exporting network") options = LeelaExportOptions( min_version="0.28", num_heads=training_state.num_heads, license=None, training_steps=new_state.step, ) export_state = ( new_state.swa_state if config.export.export_swa_model else new_state.model_state ) assert isinstance(export_state, nnx.State) net = jax_to_leela(jax_weights=export_state, export_options=options) network_bytes = gzip.compress(net.SerializeToString()) for destination_template in config.export.destination_filename: destination = destination_template.format( datetime=date_str, step=new_state.step ) logging.info(f"Writing network to {destination}") os.makedirs(os.path.dirname(destination), exist_ok=True) with open(destination, "wb") as f: f.write(network_bytes) logging.info(f"Finished writing network to {destination}") def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) train(config_filename=args.config) 
return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/training_eval.py ================================================ import argparse import logging import sys from lczero_training.commands import configure_root_logging def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Evaluate a trained model.") parser.add_argument( "--config", type=str, required=True, help="Path to the training config file.", ) parser.add_argument( "--num-samples", type=int, help="Number of samples to evaluate.", ) parser.add_argument( "--batch-size", type=int, help="Override batch size from data loader config.", ) parser.add_argument( "--dump-stdout", action="store_true", help="Dump input/output tensors to stdout.", ) parser.add_argument( "--dump-file", type=str, help="Dump input/output tensors to specified file.", ) parser.add_argument( "--dump-shelve", type=str, help="Dump input/output tensors to specified shelve database.", ) parser.add_argument( "--dump-json", type=str, help="Dump input/output tensors to specified JSON file.", ) parser.add_argument( "--onnx-model", type=str, help="Path to an ONNX model to compare against JAX outputs.", ) parser.add_argument( "--no-softmax-jax-wdl", action="store_true", help="Disable softmaxing the JAX WDL head before comparison.", ) return parser def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) # Lazy import to keep --help responsive and avoid heavy deps unless needed. 
from lczero_training.training.eval import eval as eval_fn eval_fn( config_filename=args.config, num_samples=args.num_samples, batch_size_override=args.batch_size, dump_to_stdout=args.dump_stdout, dump_to_file=args.dump_file, dump_to_shelve=args.dump_shelve, dump_to_json=args.dump_json, onnx_model=args.onnx_model, softmax_jax_wdl=not args.no_softmax_jax_wdl, ) return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/training_init.py ================================================ import argparse import logging import sys from lczero_training.commands import configure_root_logging def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Initialize a new training run from a config file." ) parser.add_argument( "--config", type=str, required=True, help="Path to the training config file.", ) parser.add_argument( "--lczero_model", type=str, help="Path to an existing lczero model to start from.", ) parser.add_argument( "--seed", type=int, default=42, help="Seed for initializing model parameters.", ) parser.add_argument( "--dry-run", action="store_true", help="Skip checkpoint creation.", ) parser.add_argument( "--swa_initial_nets", type=int, default=0, help="Initial value for num_averages in SWA state.", ) parser.add_argument( "--override_training_steps", type=int, help="Override training step number.", ) parser.add_argument( "--overwrite", action="store_true", help="Allow overwriting existing checkpoint.", ) parser.add_argument( "--no-copy-swa", action="store_true", help="Don't copy model weights to SWA state.", ) parser.add_argument( "--ignore-config-mismatch", action="store_true", help="Ignore lczero model config mismatch.", ) return parser def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) # Lazy import to keep --help responsive and avoid heavy deps unless needed. 
from lczero_training.training.init import init

    # (Tail of training_init.py main(): forward the parsed CLI flags to the
    # heavy-weight init() implementation, which was lazily imported above.)
    init(
        config_filename=args.config,
        lczero_model=args.lczero_model,
        seed=args.seed,
        dry_run=args.dry_run,
        swa_initial_nets=args.swa_initial_nets,
        override_training_steps=args.override_training_steps,
        overwrite=args.overwrite,
        no_copy_swa=args.no_copy_swa,
        ignore_config_mismatch=args.ignore_config_mismatch,
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())


================================================
FILE: src/lczero_training/commands/tui.py
================================================
import argparse
import logging
import sys

import anyio

from lczero_training.commands import configure_root_logging
from lczero_training.tui.app import TrainingTuiApp


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; the argument definitions live on the app class."""
    parser = argparse.ArgumentParser(description="Training TUI runner")
    TrainingTuiApp.add_arguments(parser)
    return parser


async def _amain(args: argparse.Namespace) -> None:
    # Async entry point: construct the TUI app and run it to completion.
    app = TrainingTuiApp(args=args)
    await app.run_async()


def main(argv: list[str] | None = None) -> int:
    """Parse arguments and run the training TUI under the anyio event loop."""
    configure_root_logging(logging.INFO)
    parser = _build_parser()
    args = parser.parse_args(argv)
    anyio.run(_amain, args)
    return 0


if __name__ == "__main__":
    sys.exit(main())


================================================
FILE: src/lczero_training/commands/tune_lr.py
================================================
import argparse
import logging
import sys

from lczero_training.commands import configure_root_logging


def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Run a learning rate tuning sweep."
) parser.add_argument( "--config", type=str, required=True, help="Path to the training config file.", ) parser.add_argument( "--start-lr", type=float, required=True, help="Starting learning rate for the sweep.", ) parser.add_argument( "--num-steps", type=int, required=True, help="Number of training steps to evaluate.", ) parser.add_argument( "--multiplier", type=float, default=1.01, help="Multiplier applied to the learning rate after each step.", ) parser.add_argument( "--warmup-steps", type=int, default=0, help=( "Optional number of warmup steps to run at a fixed learning rate before " "the exponential sweep." ), ) parser.add_argument( "--warmup-lr", type=float, help=( "Learning rate to use during warmup steps. Required when --warmup-steps > 0." ), ) parser.add_argument( "--csv-output", type=str, help=( "Optional path to write CSV results. Columns: lr, train_loss[, val_loss]." ), ) parser.add_argument( "--plot-output", type=str, help="Optional path to save a matplotlib plot of the sweep.", ) parser.add_argument( "--num-test-batches", type=int, default=0, help=( "When > 0, also compute and report validation loss on this many fixed batches " "(averaged each step). Default 0 (training loss only)." ), ) return parser def main(argv: list[str] | None = None) -> int: configure_root_logging(logging.INFO) parser = _build_parser() args = parser.parse_args(argv) # Lazy import to keep --help responsive and avoid heavy deps unless needed. 
from lczero_training.training.tune_lr import tune_lr tune_lr( config_filename=args.config, start_lr=args.start_lr, num_steps=args.num_steps, multiplier=args.multiplier, warmup_steps=args.warmup_steps, warmup_lr=args.warmup_lr, csv_output=args.csv_output, plot_output=args.plot_output, num_test_batches=args.num_test_batches, ) return 0 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: src/lczero_training/commands/weights_tool.py ================================================ """CLI command for manipulating Lc0 neural network weights.""" import argparse import sys import numpy as np from lczero_training.commands import configure_root_logging def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Manipulate Lc0 neural network weights." ) parser.add_argument( "--expr", type=str, help=( "Python expression to execute " "(e.g., \"output = weights('A.pb') * 0.5\")" ), ) parser.add_argument( "script", nargs="?", help="Path to Python script file (if --expr not provided)", ) parser.add_argument( "--input", type=str, action="append", help="Pre-load input as NAME=PATH (e.g., --input A=net_A.pb.gz)", ) parser.add_argument( "--output", type=str, help='Default output path if "output" variable is set', ) parser.add_argument( "--encoding", type=str, default="FLOAT16", choices=["LINEAR16", "FLOAT16", "BFLOAT16"], help="Output encoding format (default: FLOAT16)", ) return parser def main(argv: list[str] | None = None) -> int: configure_root_logging() parser = _build_parser() args = parser.parse_args(argv) # Lazy import to avoid heavy dependencies on --help. from lczero_training.tools.weights_tool import load_weights, save_weights from proto import net_pb2 # Build execution environment. env = { "np": np, "weights": load_weights, "save": save_weights, "lc0": net_pb2, } # Pre-load inputs. 
if args.input:
    # (Tail of weights_tool.py main().) Pre-load each --input NAME=PATH into
    # the script environment so the user script can refer to nets by name.
    for input_spec in args.input:
        if "=" not in input_spec:
            print(
                f"Error: Invalid input spec '{input_spec}'. "
                "Expected format: NAME=PATH",
                file=sys.stderr,
            )
            return 1
        name, path = input_spec.split("=", 1)
        env[name] = load_weights(path)

    # Determine script source: --expr, file, or stdin.
    if args.expr:
        script = args.expr
    elif args.script:
        with open(args.script) as f:
            script = f.read()
    else:
        # Reading from a TTY would hang waiting for input; fail with guidance.
        if sys.stdin.isatty():
            print(
                "Error: No script provided. Use --expr, provide script file, "
                "or pipe to stdin.",
                file=sys.stderr,
            )
            return 1
        script = sys.stdin.read()

    # Execute script.
    # SECURITY NOTE: exec() of a user-supplied script is this tool's whole
    # purpose (a weights calculator); never feed it untrusted input.
    try:
        exec(script, env)
    except Exception as e:
        print(f"Error executing script: {e}", file=sys.stderr)
        return 1

    # Auto-save if 'output' variable is set.
    if "output" in env and args.output:
        from lczero_training.tools.weight_wrappers import NetWrapper

        output = env["output"]
        if isinstance(output, NetWrapper):
            save_weights(output, args.output, args.encoding)
            print(f"Saved result to {args.output}")

    return 0


if __name__ == "__main__":
    sys.exit(main())


================================================
FILE: src/lczero_training/convert/__init__.py
================================================
"""Convert package for Leela Chess Zero training."""


================================================
FILE: src/lczero_training/convert/jax_to_leela.py
================================================
import dataclasses
import logging
from typing import Optional, cast

import numpy as np
from flax import nnx

from lczero_training.convert.leela_pytree_visitor import (
    LeelaPytreeWeightsVisitor,
)
from proto import net_pb2

logger = logging.getLogger(__name__)

# Input-embedding kernel row that is rescaled on export; mirrors the inverse
# scaling applied on import (see LeelaToJax.embedding_block in leela_to_jax.py).
_EMBEDDING_PLANE_TO_SCALE = 109
_EMBEDDING_SCALE = 99.0


class JaxToLeela(LeelaPytreeWeightsVisitor):
    """Visitor that writes JAX (nnx) weights into a Leela net_pb2.Net."""

    def embedding_block(
        self, nnx_dict: nnx.State, weights: net_pb2.Weights
    ) -> None:
        # Temporarily rescale one embedding plane before the base visitor
        # serializes the kernel; the original value is restored afterwards.
        embedding_kernel = cast(nnx.Param, nnx_dict["embedding"]["kernel"])
        original_values = embedding_kernel.value
        arr = np.asarray(original_values).copy()
arr[_EMBEDDING_PLANE_TO_SCALE] /= _EMBEDDING_SCALE
        embedding_kernel.value = arr
        try:
            super().embedding_block(nnx_dict=nnx_dict, weights=weights)
        finally:
            # Always restore the unscaled kernel so export does not mutate
            # the caller's training state.
            embedding_kernel.value = original_values

    def tensor(
        self,
        param: nnx.Param,
        leela: net_pb2.Weights.Layer,
    ) -> None:
        """Quantize one JAX parameter into Leela's LINEAR16 layer encoding.

        The tensor is transposed, flattened, min/max-normalized to [0, 1],
        then stored as little-endian uint16 codes plus (min_val, max_val)
        so the original values can be reconstructed linearly.
        """
        weights = np.asarray(param.value, dtype=np.float32).T.flatten()
        min_val, max_val = np.min(weights), np.max(weights)
        range_val = max_val - min_val
        # Normalize to [0, 1], handling the case where all weights are equal.
        normalized = np.where(
            range_val > 1e-8, (weights - min_val) / range_val, 0.5
        )
        # Scale to uint16 and convert to bytes.
        quantized = np.round(normalized * 65535.0).astype(np.uint16)
        leela.params = quantized.tobytes()
        leela.min_val = float(min_val)
        leela.max_val = float(max_val)
        # Each element takes 2 bytes; sanity-check the payload size.
        assert len(leela.params) // 2 == weights.size

    def encoder_tower(
        self, nnx_dict: nnx.State, weights: net_pb2.Weights
    ) -> None:
        # Pre-create one EncoderLayer message per JAX encoder block so the
        # base visitor has a destination to fill for each layer.
        for i in range(len(nnx_dict["encoders"]["layers"])):
            weights.encoder.append(weights.EncoderLayer())
        return super().encoder_tower(nnx_dict=nnx_dict, weights=weights)


@dataclasses.dataclass
class LeelaExportOptions:
    """Options controlling how a JAX state is exported to a Leela net."""

    min_version: str  # Minimum lc0 version string, e.g. "0.31".
    num_heads: int  # Written to weights.headcount.
    license: Optional[str]  # Optional license text embedded in the net.
    training_steps: Optional[int] = None  # Optional training step counter.


def jax_to_leela(
    jax_weights: nnx.State, export_options: LeelaExportOptions
) -> net_pb2.Net:
    """Convert a JAX nnx.State into a populated Leela net_pb2.Net."""
    lc0_weights = net_pb2.Net()
    lc0_weights.magic = 0x1C0
    if export_options.license:
        lc0_weights.license = export_options.license
    (
        lc0_weights.min_version.major,
        lc0_weights.min_version.minor,
        lc0_weights.min_version.patch,
    ) = _split_version(export_options.min_version)
    lc0_weights.format.CopyFrom(_make_format())
    if export_options.training_steps is not None:
        lc0_weights.training_params.training_steps = (
            export_options.training_steps
        )
    visitor = JaxToLeela(jax_weights, lc0_weights)
    lc0_weights.weights.headcount = export_options.num_heads
    visitor.run()
    return lc0_weights


def _split_version(version_str: str) -> tuple[int, int, int]:
    """Splits a version string like "v12.34.56" into (12, 34, 56)."""
parts = (version_str.lstrip("v").split(".") + ["0", "0"])[:3]
    # Pad with zeros so "0.31" -> (0, 31, 0); extra components are dropped.
    return cast(tuple[int, int, int], tuple(map(int, parts)))


def _make_format() -> net_pb2.Format:
    """Build the fixed net format descriptor used for all exported networks.

    Hard-codes the attention-body multihead architecture that this exporter
    produces (LINEAR16 weight encoding, WDL value head, attention policy,
    Mish default activation, PE-dense input embedding).
    """
    fmt = net_pb2.Format()
    fmt.weights_encoding = fmt.LINEAR16
    netfmt = fmt.network_format
    netfmt.input = netfmt.INPUT_CLASSICAL_112_PLANE
    netfmt.output = netfmt.OUTPUT_WDL
    netfmt.network = netfmt.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT
    netfmt.policy = netfmt.POLICY_ATTENTION
    netfmt.value = netfmt.VALUE_WDL
    netfmt.moves_left = netfmt.MOVES_LEFT_V1
    netfmt.default_activation = netfmt.DEFAULT_ACTIVATION_MISH
    netfmt.smolgen_activation = netfmt.ACTIVATION_SWISH
    netfmt.ffn_activation = netfmt.ACTIVATION_DEFAULT
    netfmt.input_embedding = netfmt.INPUT_EMBEDDING_PE_DENSE
    return fmt


================================================
FILE: src/lczero_training/convert/leela_pytree_visitor.py
================================================
import math
from typing import Any, Optional

from flax import nnx

from proto import net_pb2


class LeelaPytreeWeightsVisitor:
    """Walks an nnx.State and a Leela net_pb2.Net in lockstep.

    Subclasses override tensor() to copy data in one direction or the other
    (JAX -> Leela on export, Leela -> JAX on import); this base class only
    encodes the structural correspondence between the two representations.
    """

    def __init__(self, nnx_state: nnx.State, leela_net: net_pb2.Net) -> None:
        self.leela_net = leela_net
        self.nnx_state = nnx_state

    def run(self) -> None:
        """Visit the whole network: embedding, encoders, then all heads."""
        state = self.nnx_state
        weights = self.leela_net.weights
        self.embedding_block(state["embedding"], weights)
        self.encoder_tower(state["encoders"], weights)
        self.policy_heads(state, weights.policy_heads)
        # Value heads are optional; visit whichever of these exist in state.
        for head_name in ["winner", "q", "st"]:
            if head_name in state["value_heads"]:
                self.value_head(
                    state["value_heads"][head_name],
                    getattr(weights.value_heads, head_name),
                )
        # The "main" moves-left head is mandatory.
        for head_name in ["main"]:
            assert head_name in state["movesleft_heads"], (
                f"movesleft head {head_name} missing in state"
            )
            self.movesleft_head(state["movesleft_heads"][head_name], weights)

    def embedding_block(
        self, nnx_dict: nnx.State, weights: net_pb2.Weights
    ) -> None:
        self.matmul(
            nnx_dict["preprocess"],
            weights.ip_emb_preproc_w,
            weights.ip_emb_preproc_b,
        )
        self.matmul(
            nnx_dict["embedding"],
            weights.ip_emb_w,
            weights.ip_emb_b,
        )
        self.layernorm(
nnx_dict["norm"], weights.ip_emb_ln_gammas, weights.ip_emb_ln_betas, ) self.tensor( nnx_dict["ma_gating"]["mult_gate"]["gate"], weights.ip_mult_gate ) self.tensor( nnx_dict["ma_gating"]["add_gate"]["gate"], weights.ip_add_gate ) self.ffn(nnx_dict["ffn"], weights.ip_emb_ffn) self.layernorm( nnx_dict["out_norm"], weights.ip_emb_ffn_ln_gammas, weights.ip_emb_ffn_ln_betas, ) def encoder_tower( self, nnx_dict: nnx.State, weights: net_pb2.Weights ) -> None: # Shared layer is stored at the point of the first usage. self.matmul( nnx_dict["encoders"]["layers"][0]["mha"]["smolgen"][ "weight_gen_dense" ], weights.smolgen_w, None, ) # assert len(nnx_dict["encoders"]["layers"]) == len(weights.encoder) for i in range(len(nnx_dict["encoders"]["layers"])): self.encoder_block( nnx_dict["encoders"]["layers"][i], weights.encoder[i] ) def encoder_block( self, nnx_dict: nnx.State, weights: net_pb2.Weights.EncoderLayer ) -> None: self.mha(nnx_dict["mha"], weights.mha) self.layernorm(nnx_dict["ln1"], weights.ln1_gammas, weights.ln1_betas) self.ffn(nnx_dict["ffn"], weights.ffn) self.layernorm(nnx_dict["ln2"], weights.ln2_gammas, weights.ln2_betas) def mha(self, nnx_dict: nnx.State, weights: net_pb2.Weights.MHA) -> None: self.matmul(nnx_dict["q"], weights.q_w, weights.q_b) self.matmul(nnx_dict["k"], weights.k_w, weights.k_b) self.matmul(nnx_dict["v"], weights.v_w, weights.v_b) self.smolgen(nnx_dict["smolgen"], weights.smolgen) self.matmul(nnx_dict["output_dense"], weights.dense_w, weights.dense_b) def smolgen( self, nnx_dict: nnx.State, weights: net_pb2.Weights.Smolgen ) -> None: self.matmul(nnx_dict["compress"], weights.compress, None) self.matmul(nnx_dict["dense1"], weights.dense1_w, weights.dense1_b) self.layernorm(nnx_dict["ln1"], weights.ln1_gammas, weights.ln1_betas) self.matmul(nnx_dict["dense2"], weights.dense2_w, weights.dense2_b) self.layernorm(nnx_dict["ln2"], weights.ln2_gammas, weights.ln2_betas) def layernorm( self, nnx_dict: nnx.State, scales: net_pb2.Weights.Layer, biases: 
net_pb2.Weights.Layer, ) -> None: self.tensor(nnx_dict["scale"], scales) self.tensor(nnx_dict["bias"], biases) def policy_heads( self, nnx_dict: nnx.State, weights: net_pb2.Weights.PolicyHeads ) -> None: if "policy_embedding_shared" in nnx_dict: self.matmul( nnx_dict["policy_embedding_shared"], weights.ip_pol_w, weights.ip_pol_b, ) policy_heads_dict = nnx_dict["policy_heads"] for head_name in ["vanilla", "optimistic_st", "soft", "opponent"]: if head_name in policy_heads_dict: self.policy_head( policy_heads_dict[head_name], getattr(weights, head_name) ) def policy_head( self, nnx_dict: nnx.State, weights: net_pb2.Weights.PolicyHead ) -> None: if "tokens" in nnx_dict: self.matmul(nnx_dict["tokens"], weights.ip_pol_w, weights.ip_pol_b) self.matmul(nnx_dict["q"], weights.ip2_pol_w, weights.ip2_pol_b) self.matmul(nnx_dict["k"], weights.ip3_pol_w, weights.ip3_pol_b) self.matmul(nnx_dict["promotion_dense"], weights.ip4_pol_w, None) def value_head( self, nnx_dict: nnx.State, weights: net_pb2.Weights.ValueHead ) -> None: self.matmul(nnx_dict["embed"], weights.ip_val_w, weights.ip_val_b) self.matmul(nnx_dict["dense1"], weights.ip1_val_w, weights.ip1_val_b) self.matmul(nnx_dict["wdl"], weights.ip2_val_w, weights.ip2_val_b) if "error" in nnx_dict: self.matmul( nnx_dict["error"], weights.ip_val_err_w, weights.ip_val_err_b ) if "categorical" in nnx_dict: self.matmul( nnx_dict["categorical"], weights.ip_val_cat_w, weights.ip_val_cat_b, ) def movesleft_head( self, nnx_dict: nnx.State, weights: net_pb2.Weights ) -> None: self.matmul(nnx_dict["embed"], weights.ip_mov_w, weights.ip_mov_b) self.matmul(nnx_dict["dense1"], weights.ip1_mov_w, weights.ip1_mov_b) self.matmul(nnx_dict["out"], weights.ip2_mov_w, weights.ip2_mov_b) def ffn(self, nnx_dict: nnx.State, ffn: net_pb2.Weights.FFN) -> None: self.matmul(nnx_dict["linear1"], ffn.dense1_w, ffn.dense1_b) self.matmul(nnx_dict["linear2"], ffn.dense2_w, ffn.dense2_b) def matmul( self, nnx_dict: nnx.State, weights: net_pb2.Weights.Layer, 
biases: Optional[net_pb2.Weights.Layer], ) -> None: self.tensor(nnx_dict["kernel"], weights) if biases: self.tensor(nnx_dict["bias"], biases) else: assert "bias" not in nnx_dict def tensor( self, param: Any, leela: net_pb2.Weights.Layer, ) -> None: print( param.shape, len(leela.params) // 2, math.prod(param.shape), ) assert len(leela.params) // 2 == math.prod(param.shape) assert len(leela.params) != 0 ================================================ FILE: src/lczero_training/convert/leela_to_jax.py ================================================ import dataclasses import gzip import logging import math from typing import Optional, cast import jax.numpy as jnp from flax import nnx, serialization from lczero_training.model.model import LczeroModel from proto import hlo_pb2, net_pb2 from .jax_to_leela import LeelaExportOptions, jax_to_leela from .leela_pytree_visitor import LeelaPytreeWeightsVisitor from .leela_to_modelconfig import leela_to_modelconfig logger = logging.getLogger(__name__) _EMBEDDING_PLANE_TO_SCALE = 109 _EMBEDDING_SCALE = 99.0 @dataclasses.dataclass class LeelaImportOptions: weights_dtype: hlo_pb2.XlaShapeProto.Type compute_dtype: hlo_pb2.XlaShapeProto.Type def fix_older_weights_file(file: net_pb2.Net) -> None: nf = net_pb2.NetworkFormat has_network_format = file.format.HasField("network_format") network_format = ( file.format.network_format.network if has_network_format else None ) net = file.format.network_format if not has_network_format: # Older protobufs don't have format definition. net.input = nf.INPUT_CLASSICAL_112_PLANE net.output = nf.OUTPUT_CLASSICAL net.network = nf.NETWORK_CLASSICAL_WITH_HEADFORMAT net.value = nf.VALUE_CLASSICAL net.policy = nf.POLICY_CLASSICAL elif network_format == nf.NETWORK_CLASSICAL: # Populate policyFormat and valueFormat fields in old protobufs # without these fields. 
        net.network = nf.NETWORK_CLASSICAL_WITH_HEADFORMAT
        net.value = nf.VALUE_CLASSICAL
        net.policy = nf.POLICY_CLASSICAL
    elif network_format == nf.NETWORK_SE:
        net.network = nf.NETWORK_SE_WITH_HEADFORMAT
        net.value = nf.VALUE_CLASSICAL
        net.policy = nf.POLICY_CLASSICAL
    elif (
        network_format == nf.NETWORK_SE_WITH_HEADFORMAT
        and len(file.weights.encoder) > 0
    ):
        # Attention body network made with old protobuf.
        net.network = nf.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT
        if file.weights.HasField("smolgen_w"):
            # Need to override activation defaults for smolgen.
            net.ffn_activation = nf.ACTIVATION_RELU_2
            net.smolgen_activation = nf.ACTIVATION_SWISH
    elif network_format == nf.NETWORK_AB_LEGACY_WITH_MULTIHEADFORMAT:
        net.network = nf.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT
    # Re-read the (possibly just-updated) network field: old attention-body
    # files that carry multihead weights get the multihead format flag and a
    # dense positional embedding.
    if (
        file.format.network_format.network
        == nf.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT
    ):
        weights = file.weights
        if weights.HasField("policy_heads") and weights.HasField("value_heads"):
            logger.info(
                "Weights file has multihead format, updating format flag"
            )
            net.network = nf.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT
            net.input_embedding = nf.INPUT_EMBEDDING_PE_DENSE
    # Any file that still has no explicit input_embedding defaults to PE_MAP.
    # Order matters: the PE_DENSE assignment above makes HasField true.
    if not file.format.network_format.HasField("input_embedding"):
        net.input_embedding = nf.INPUT_EMBEDDING_PE_MAP


class LeelaToJax(LeelaPytreeWeightsVisitor):
    """Visitor that copies Leela protobuf weights into a JAX nnx.State."""

    def embedding_block(
        self, nnx_dict: nnx.State, weights: net_pb2.Weights
    ) -> None:
        """Copy embedding weights, then rescale one input plane's column.

        The row for plane _EMBEDDING_PLANE_TO_SCALE is multiplied by
        _EMBEDDING_SCALE after the plain copy done by the base class.
        """
        super().embedding_block(nnx_dict=nnx_dict, weights=weights)
        embedding_kernel = cast(nnx.Param, nnx_dict["embedding"]["kernel"])
        values = embedding_kernel.value
        scaled_values = values.at[_EMBEDDING_PLANE_TO_SCALE].set(
            values[_EMBEDDING_PLANE_TO_SCALE] * _EMBEDDING_SCALE
        )
        embedding_kernel.value = scaled_values

    def tensor(
        self,
        param: nnx.Param,
        leela: net_pb2.Weights.Layer,
    ) -> None:
        """Dequantize a LINEAR16 Leela layer into `param`.

        Values are stored as uint16 in [0, 65535] and mapped linearly onto
        [min_val, max_val]. Leela stores matrices transposed relative to the
        JAX layout, hence the reversed reshape + transpose.
        """
        assert len(leela.params) // 2 == math.prod(param.shape)
        assert len(leela.params) != 0
        values = jnp.frombuffer(leela.params, dtype=jnp.uint16)
        values = values.astype(jnp.float32)
        alpha = values / 65535.0
        values = alpha * leela.max_val + (1.0 - alpha) * leela.min_val
        values = values.astype(param.dtype)
        values = values.reshape(param.shape[::-1]).transpose()
        param.value = values


def leela_to_jax(
    leela_net: net_pb2.Net, import_options: LeelaImportOptions
) -> nnx.State:
    """Build an LczeroModel from a Leela net and return its filled state."""
    config = leela_to_modelconfig(
        leela_net,
        import_options.weights_dtype,
        import_options.compute_dtype,
    )
    model = LczeroModel(config=config, rngs=nnx.Rngs(params=42))
    state = nnx.state(model)
    visitor = LeelaToJax(state, leela_net)
    visitor.run()
    return state


def leela_to_jax_files(
    input_path: str,
    weights_dtype: str,
    compute_dtype: str,
    output_modelconfig: Optional[str],
    output_serialized_jax: Optional[str],
    output_leela_verification: Optional[str],
    print_modelconfig: bool = False,
) -> None:
    """CLI-style driver: read a gzipped Leela net, emit requested outputs.

    Args:
        input_path: Path to the gzipped Leela .pb file.
        weights_dtype / compute_dtype: Names of XlaShapeProto.Type constants.
        output_modelconfig: If set, write the textual model config here.
        output_serialized_jax: If set, write serialized nnx state here.
        output_leela_verification: If set, round-trip back to Leela format.
        print_modelconfig: If True, print the model config to stdout.
    """
    lc0_weights = net_pb2.Net()
    with gzip.open(input_path, "rb") as f:
        contents = f.read()
        assert isinstance(contents, bytes)
        lc0_weights.ParseFromString(contents)
    fix_older_weights_file(lc0_weights)
    import_options = LeelaImportOptions(
        weights_dtype=getattr(hlo_pb2.XlaShapeProto, weights_dtype),
        compute_dtype=getattr(hlo_pb2.XlaShapeProto, compute_dtype),
    )
    config = leela_to_modelconfig(
        lc0_weights,
        import_options.weights_dtype,
        import_options.compute_dtype,
    )
    if print_modelconfig:
        print(config)
    if output_modelconfig:
        with open(output_modelconfig, "w") as f:
            f.write(str(config))
    # Skip the (expensive) weight conversion if no weight output is wanted.
    if output_serialized_jax is None and output_leela_verification is None:
        return
    state = leela_to_jax(lc0_weights, import_options)
    if output_serialized_jax:
        with open(output_serialized_jax, "wb") as f:
            f.write(serialization.to_bytes(state))
    if output_leela_verification:
        min_version = (
            f"v{lc0_weights.min_version.major}."
            f"{lc0_weights.min_version.minor}."
f"{lc0_weights.min_version.patch}" ) license_str = ( lc0_weights.license if lc0_weights.HasField("license") else None ) export_options = LeelaExportOptions( min_version=min_version, num_heads=lc0_weights.weights.headcount, license=license_str, training_steps=lc0_weights.training_params.training_steps, ) verification_net = jax_to_leela(state, export_options) with gzip.open(output_leela_verification, "wb") as f: f.write(verification_net.SerializeToString()) ================================================ FILE: src/lczero_training/convert/leela_to_modelconfig.py ================================================ from proto import hlo_pb2, model_config_pb2, net_pb2 def _defaultactivation_to_activation( activation: net_pb2.NetworkFormat.DefaultActivation, ) -> net_pb2.NetworkFormat.ActivationFunction: return { net_pb2.NetworkFormat.DEFAULT_ACTIVATION_RELU: net_pb2.NetworkFormat.ACTIVATION_RELU, net_pb2.NetworkFormat.DEFAULT_ACTIVATION_MISH: net_pb2.NetworkFormat.ACTIVATION_MISH, }[activation] def leela_to_modelconfig( leela_net: net_pb2.Net, weights_dtype: hlo_pb2.XlaShapeProto.Type, compute_dtype: hlo_pb2.XlaShapeProto.Type, ) -> model_config_pb2.ModelConfig: assert weights_dtype == hlo_pb2.XlaShapeProto.F32, ( "Only float32 weights are supported." 
) assert leela_net.format.weights_encoding == net_pb2.Format.LINEAR16 leela_net_format = leela_net.format.network_format model_config = model_config_pb2.ModelConfig() model_config.defaults.compute_dtype = compute_dtype model_config.defaults.activation = _defaultactivation_to_activation( leela_net_format.default_activation ) model_config.defaults.ffn_activation = ( leela_net_format.ffn_activation or model_config.defaults.activation ) assert ( leela_net_format.input_embedding == net_pb2.NetworkFormat.INPUT_EMBEDDING_PE_DENSE ), "Only dense positional embedding is supported, got {}".format( net_pb2.NetworkFormat.InputEmbeddingFormat.Name( leela_net_format.input_embedding ) ) assert leela_net_format.policy == net_pb2.NetworkFormat.POLICY_ATTENTION, ( "Only attention policy is supported, got {}".format( net_pb2.NetworkFormat.PolicyFormat.Name(leela_net_format.policy) ) ) assert leela_net_format.value == net_pb2.NetworkFormat.VALUE_WDL, ( "Only WDL value is supported, got {}".format( net_pb2.NetworkFormat.ValueFormat.Name(leela_net_format.value) ) ) assert leela_net_format.moves_left == net_pb2.NetworkFormat.MOVES_LEFT_V1, ( "Only V1 moves left format is supported, got {}".format( net_pb2.NetworkFormat.MovesLeftFormat.Name( leela_net_format.moves_left ) ) ) def size(x: net_pb2.Weights.Layer) -> int: return len(x.params) // 2 assert ( leela_net_format.network == net_pb2.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT ) weights = leela_net.weights model_config.embedding.dense_size = size(weights.ip_emb_preproc_b) // 64 model_config.embedding.embedding_size = size(weights.ip_emb_b) assert size(weights.ip_mult_gate) > 0 assert size(weights.ip_add_gate) > 0 model_config.embedding.dff = size(weights.ip_emb_ffn.dense1_b) model_config.encoder.num_blocks = len(weights.encoder) assert model_config.encoder.num_blocks > 0 encoder = weights.encoder[0] model_config.encoder.d_model = size(encoder.mha.q_b) model_config.encoder.heads = weights.headcount model_config.encoder.dff 
= size(encoder.ffn.dense1_b) if weights.HasField("smolgen_w"): model_config.encoder.smolgen.activation = ( leela_net_format.smolgen_activation or model_config.defaults.activation ) model_config.encoder.smolgen.hidden_channels = ( size(encoder.mha.smolgen.compress) // model_config.embedding.embedding_size ) model_config.encoder.smolgen.gen_size = ( size(encoder.mha.smolgen.dense2_b) // weights.headcount ) model_config.encoder.smolgen.hidden_size = size( encoder.mha.smolgen.dense1_b ) if weights.policy_heads.HasField("ip_pol_w"): model_config.shared_policy_embedding_size = size( weights.policy_heads.ip_pol_b ) for head_name in ["vanilla", "optimistic_st", "soft", "opponent"]: if weights.policy_heads.HasField(head_name): head = getattr(weights.policy_heads, head_name) assert size(head.ip2_pol_b) > 0 assert not head.HasField("ip_pol_w") policy_head = model_config.policy_head.add() policy_head.name = head_name if not model_config.HasField("shared_policy_embedding_size"): policy_head.embedding_size = size(head.ip_pol_b) policy_head.d_model = size(head.ip2_pol_b) for head_name in ["winner", "q", "st"]: if weights.value_heads.HasField(head_name): head = getattr(weights.value_heads, head_name) assert size(head.ip_val_b) > 0 value_head = model_config.value_head.add() value_head.name = head_name value_head.num_channels = size(head.ip_val_b) if head.HasField("ip_val_err_w"): value_head.has_error_output = True if head.HasField("ip_val_cat_b"): value_head.num_categorical_buckets = size(head.ip_val_cat_b) movesleft_head = model_config.movesleft_head.add() movesleft_head.name = "main" movesleft_head.num_channels = size(weights.ip_mov_b) return model_config ================================================ FILE: src/lczero_training/daemon/__init__.py ================================================ # ABOUTME: Daemon package for training subprocess communication. # ABOUTME: Provides TrainingDaemon class for IPC via stdin/stdout. 
================================================
FILE: src/lczero_training/daemon/daemon.py
================================================
import logging
import signal
import sys
import threading
import time

import anyio

import proto.training_metrics_pb2 as training_metrics_pb2

from .pipeline import TrainingPipeline
from .protocol.communicator import Communicator
from .protocol.messages import (
    StartTrainingImmediatelyPayload,
    StartTrainingPayload,
    TrainingStatusPayload,
)


class TrainingDaemon:
    """Training subprocess entry point.

    Spawns three daemon threads (stdio communicator, async metrics reporter,
    signal waiter) and then runs the training pipeline on the calling thread
    once a config path arrives via on_start_training.
    """

    _training_pipeline: TrainingPipeline | None = None
    _config_filepath: str | None = None
    _daemon_start_time: float

    def __init__(self, memory_profile_dir: str | None = None) -> None:
        self._memory_profile_dir = memory_profile_dir
        self._daemon_start_time = time.time()
        # Order matters: logging first, then signal masking so that all
        # threads created below inherit the blocked-signal mask.
        self._setup_logging()
        self._setup_signal_handling()
        self._communicator = Communicator(self, sys.stdin, sys.stdout)
        self._communicator_thread = threading.Thread(
            target=lambda: self._communicator.run(), daemon=True
        )
        self._communicator_thread.start()
        self._async_thread = threading.Thread(
            target=lambda: anyio.run(self._metrics_main), daemon=True
        )
        self._async_thread.start()
        self._signal_thread = threading.Thread(
            target=self._signal_handler_thread, daemon=True
        )
        self._signal_thread.start()

    def _setup_logging(self) -> None:
        """Route log output to stderr (stdout carries the IPC protocol)."""
        logging.basicConfig(
            level=logging.INFO,
            format=(
                "%(levelname).1s%(asctime)s.%(msecs)03d %(name)s "
                "%(filename)s:%(lineno)d] %(message)s"
            ),
            datefmt="%m%d %H:%M:%S",
            stream=sys.stderr,
        )
        logging.info("TrainingDaemon starting up")

    def _setup_signal_handling(self) -> None:
        # Block SIGINT and SIGTERM on all threads
        signal.pthread_sigmask(
            signal.SIG_BLOCK, {signal.SIGINT, signal.SIGTERM}
        )

    def _signal_handler_thread(self) -> None:
        # Wait for SIGINT or SIGTERM
        signum = signal.sigwait({signal.SIGINT, signal.SIGTERM})
        self._shutdown(signum)

    def _shutdown(self, signum: int) -> None:
        """Stop the pipeline (if any) in response to a termination signal."""
        logging.info(f"Received signal {signum}, shutting down...")
        if self._training_pipeline:
            self._training_pipeline.stop()

    async def _metrics_main(self) -> None:
        """Async entry point for the metrics reporting thread."""
        async with anyio.create_task_group() as tg:
            tg.start_soon(self._metrics_task)

    async def _metrics_task(self) -> None:
        """Periodically push TrainingStatusPayload messages to the parent.

        Runs forever; every ~1.1 s it samples data-loader metrics (when a
        pipeline and loader exist) and sends a status payload, possibly with
        all-None fields before the pipeline starts.
        """
        while True:
            await anyio.sleep(1.1)
            dataloader_1_second = None
            dataloader_total = None
            dataloader_update_secs = None
            training_schedule_data = None
            data_loader = None
            if self._training_pipeline:
                data_loader = self._training_pipeline.get_data_loader()
                training_schedule_data = (
                    self._training_pipeline.get_training_schedule_data(
                        self._daemon_start_time
                    )
                )
            if data_loader is not None:
                stats_1_second_bytes, _ = data_loader.get_bucket_metrics(
                    0, False
                )  # k1Second = 0
                stats_total_bytes, dataloader_update_secs = (
                    data_loader.get_aggregate_ending_now(float("inf"), False)
                )
                dataloader_1_second = (
                    training_metrics_pb2.DataLoaderMetricsProto()
                )
                dataloader_1_second.ParseFromString(stats_1_second_bytes)
                dataloader_total = training_metrics_pb2.DataLoaderMetricsProto()
                dataloader_total.ParseFromString(stats_total_bytes)
            payload = TrainingStatusPayload(
                dataloader_update_secs=dataloader_update_secs,
                dataloader_1_second=dataloader_1_second,
                dataloader_total=dataloader_total,
                training_schedule=training_schedule_data,
            )
            self._communicator.send(payload)

    def run(self) -> None:
        """Block until a config arrives, then run the pipeline forever."""
        while self._config_filepath is None:
            logging.info("Waiting for training config...")
            time.sleep(1)
        logging.info("Config received. Starting training pipeline.")
        self._training_pipeline = TrainingPipeline(
            self._config_filepath,
            memory_profile_dir=self._memory_profile_dir,
        )
        self._training_pipeline.run()

    def on_start_training(self, payload: StartTrainingPayload) -> None:
        """IPC callback: record the config path that run() is polling for."""
        self._config_filepath = payload.config_filepath

    def on_start_training_immediately(
        self, payload: StartTrainingImmediatelyPayload
    ) -> None:
        """IPC callback: skip the chunk wait for the next training cycle."""
        if not self._training_pipeline:
            logging.warning(
                "Received immediate training request before pipeline initialization."
            )
            return
        self._training_pipeline.start_training_immediately()


================================================
FILE: src/lczero_training/daemon/metrics.py
================================================
"""Metrics collection and logging for training daemon."""

import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

from lczero_training._lczero_training import DataLoader
from lczero_training.daemon.metrics_base import _Metric
from lczero_training.daemon.rms_metrics import _RmsMetric
from lczero_training.model.loss_function import LczeroLoss
from lczero_training.training.state import JitTrainingState, TrainingSample
from lczero_training.training.tensorboard import TensorboardLogger
from lczero_training.training.training import StepHookData
from proto.metrics_config_pb2 import MetricConfig, MetricsConfig

logger = logging.getLogger(__name__)

# Type alias for batch tuple returned by DataLoader
# Using ... to allow variable length for compatibility with maybe_get_next
BatchTuple = tuple[np.ndarray, ...]


@dataclass
class CachedBatch:
    """Cached batch data with the global step when it was last updated."""

    batch: BatchTuple
    global_step: int


def load_batch_from_npz(npz_filename: str) -> BatchTuple:
    """Load a batch from an NPZ file.

    Args:
        npz_filename: Path to the NPZ file.

    Returns:
        BatchTuple (tuple of inputs, probabilities, values arrays).

    Raises:
        ValueError: If the NPZ file doesn't contain exactly one batch.
    """
    with np.load(npz_filename, allow_pickle=True) as npz_file:
        batches = npz_file["batches"]
        if batches.size != 1:
            raise ValueError(
                f"Expected 1 batch in npz '{npz_filename}', got {batches.size}"
            )
        return batches[0]


class _TrainingBatchMetric(_Metric):
    """Metric that logs training batch data."""

    def __init__(self, config: MetricConfig, logger: TensorboardLogger):
        super().__init__(config, logger)
        if config.use_swa_model:
            raise ValueError(
                f"Metric '{config.name}': Cannot use SWA model for "
                "training_batch metrics"
            )

    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
        # Re-logs the metrics already computed by the training step.
        self.logger.log(hook_data.global_step, hook_data.metrics)


class _EvaluatingMetric(_Metric, ABC):
    """Base class for metrics that evaluate loss on data."""

    def __init__(
        self,
        config: MetricConfig,
        logger: TensorboardLogger,
        loss_fn: Optional[LczeroLoss],
    ):
        super().__init__(config, logger)
        if not loss_fn:
            raise ValueError(f"Metric '{config.name}': Loss function required")
        self.loss_fn = loss_fn

    @abstractmethod
    def get_batch(self) -> BatchTuple:
        """Get the batch data to evaluate."""

    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
        batch = self.get_batch()
        metrics = self._evaluate(batch, hook_data.jit_state, graphdef)
        self.logger.log(hook_data.global_step, metrics)

    def _evaluate(
        self,
        batch: BatchTuple,
        jit_state: JitTrainingState,
        graphdef: nnx.GraphDef,
    ) -> Dict[str, jax.Array]:
        # Pick the SWA or regular weights depending on the metric config.
        model_state = (
            jit_state.swa_state
            if self.config.use_swa_model
            else jit_state.model_state
        )
        if model_state is None:
            raise RuntimeError("SWA state not available")
        batch_sample = TrainingSample(
            inputs=jnp.asarray(batch[0]),
            probabilities=jnp.asarray(batch[1]),
            values=jnp.asarray(batch[2]),
        )
        return _make_eval_jit(graphdef, self.loss_fn)(model_state, batch_sample)


_EvalJit = Callable[[nnx.State, TrainingSample], Dict[str, jax.Array]]
# NOTE(review): keys are id() pairs but the cache holds no strong reference
# to graphdef/loss_fn, so id reuse after GC could in principle return a stale
# jitted function — confirm the owners outlive the cache.
_eval_jit_cache: dict[tuple[int, int], _EvalJit] = {}


def _make_eval_jit(graphdef: nnx.GraphDef, loss_fn: LczeroLoss) -> _EvalJit:
    """Return a (cached) jitted mean-loss evaluator for this graphdef/loss."""
    key = (id(graphdef), id(loss_fn))
    if key not in _eval_jit_cache:

        @jax.jit
        def _eval(
            model_state: nnx.State, batch_sample: TrainingSample
        ) -> Dict[str, jax.Array]:
            model = nnx.merge(graphdef, model_state)
            # vmap over the batch dimension of the sample only.
            loss_vfn = jax.vmap(loss_fn, in_axes=(None, 0), out_axes=0)
            per_sample_loss, unweighted = loss_vfn(model, batch_sample)
            return {
                "loss": jnp.mean(per_sample_loss),
                "unweighted_losses": jax.tree_util.tree_map(
                    jnp.mean, unweighted
                ),
            }

        _eval_jit_cache[key] = _eval
    return _eval_jit_cache[key]


def evaluate_batch(
    batch: BatchTuple,
    jit_state: JitTrainingState,
    graphdef: nnx.GraphDef,
    loss_fn: LczeroLoss,
    use_swa_model: bool = False,
) -> Dict[str, jax.Array]:
    """Evaluate loss function on a batch of data.

    Args:
        batch: BatchTuple (inputs, probabilities, values).
        jit_state: JIT training state containing model and optimizer state.
        graphdef: Graph definition of the model.
        loss_fn: Loss function to evaluate.
        use_swa_model: If True, use SWA model state instead of regular model.

    Returns:
        Dictionary of metrics with loss and unweighted losses.
    """
    model_state = (
        jit_state.swa_state if use_swa_model else jit_state.model_state
    )
    if model_state is None:
        raise RuntimeError("SWA state not available")
    batch_sample = TrainingSample(
        inputs=jnp.asarray(batch[0]),
        probabilities=jnp.asarray(batch[1]),
        values=jnp.asarray(batch[2]),
    )
    return _make_eval_jit(graphdef, loss_fn)(model_state, batch_sample)


class _DataLoaderMetric(_EvaluatingMetric):
    """Metric that evaluates loss on dataloader output."""

    def __init__(
        self,
        config: MetricConfig,
        logger: TensorboardLogger,
        loss_fn: Optional[LczeroLoss],
        data_loader: Optional[DataLoader],
        dataloader_name: str,
        cached_batches: Dict[str, CachedBatch],
    ):
        super().__init__(config, logger, loss_fn)
        if not data_loader:
            raise ValueError(f"Metric '{config.name}': DataLoader required")
        self.data_loader = data_loader
        self.dataloader_name = dataloader_name
        # Shared across metrics so several metrics on the same dataloader
        # evaluate the same batch per step.
        self.cached_batches = cached_batches

    def get_batch(self) -> BatchTuple:
        return self.cached_batches[self.dataloader_name].batch

    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
        """Update cache if needed and log the metric."""
        cached = self.cached_batches.get(self.dataloader_name)
        if cached is None or cached.global_step != hook_data.global_step:
            batch = self.data_loader.maybe_get_next(self.dataloader_name)
            if batch is not None:
                self.cached_batches[self.dataloader_name] = CachedBatch(
                    batch, hook_data.global_step
                )
            elif cached is None:
                # No fresh batch and nothing cached: cannot evaluate at all.
                raise RuntimeError(
                    f"No data for metric '{self.config.name}' "
                    f"from dataloader '{self.dataloader_name}'"
                )
        super().log(hook_data, graphdef)


class _NpzMetric(_EvaluatingMetric):
    """Metric that evaluates loss on pre-loaded NPZ data."""

    def __init__(
        self,
        config: MetricConfig,
        logger: TensorboardLogger,
        loss_fn: Optional[LczeroLoss],
        npz_filename: str,
        cached_batches: Dict[str, CachedBatch],
    ):
        super().__init__(config, logger, loss_fn)
        self.npz_filename = npz_filename
        self.cached_batches = cached_batches

    def get_batch(self) -> BatchTuple:
        return self.cached_batches[self.npz_filename].batch

    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
        """Load NPZ data if needed and log the metric."""
        # The NPZ batch is fixed, so it is loaded at most once.
        if self.npz_filename not in self.cached_batches:
            batch = load_batch_from_npz(self.npz_filename)
            self.cached_batches[self.npz_filename] = CachedBatch(
                batch, hook_data.global_step
            )
        super().log(hook_data, graphdef)


class Metrics:
    """Manages metrics collection and logging for training."""

    def __init__(
        self,
        config: MetricsConfig,
        loss_fn: Optional[LczeroLoss] = None,
        data_loader: Optional[DataLoader] = None,
    ):
        self._metrics: Dict[str, _Metric] = {}
        self._cached_batches: Dict[str, CachedBatch] = {}
        for mc in config.metric:
            # One TensorBoard run directory per metric name.
            tb_logger = TensorboardLogger(
                os.path.join(config.tensorboard_path, mc.name)
            )
            metric: _Metric
            if mc.HasField("training_batch"):
                metric = _TrainingBatchMetric(mc, tb_logger)
            elif mc.HasField("dataloader_output"):
                metric = _DataLoaderMetric(
                    mc,
                    tb_logger,
                    loss_fn,
                    data_loader,
                    mc.dataloader_output,
                    self._cached_batches,
                )
            elif mc.HasField("npz_filename"):
                metric = _NpzMetric(
                    mc,
                    tb_logger,
                    loss_fn,
                    mc.npz_filename,
                    self._cached_batches,
                )
            elif mc.HasField("weights"):
                if mc.weights.rms:
                    metric = _RmsMetric(mc, tb_logger)
                else:
                    raise ValueError(
                        f"Metric '{mc.name}': No weight metric type specified"
                    )
            else:
                raise ValueError(f"Metric '{mc.name}' has no sample source")
            self._metrics[mc.name] = metric

    def on_step(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
        """Process metrics for the current step."""
        for metric in self._metrics.values():
            if metric.should_log(
                hook_data.global_step,
                hook_data.local_step,
                hook_data.steps_per_epoch,
            ):
                metric.log(hook_data, graphdef)

    def close(self) -> None:
        """Close all TensorBoard loggers."""
        for metric in self._metrics.values():
            metric.logger.close()


================================================
FILE: src/lczero_training/daemon/metrics_base.py
================================================
"""Base classes for metrics."""

from abc
import ABC, abstractmethod

from flax import nnx

from lczero_training.training.tensorboard import TensorboardLogger
from lczero_training.training.training import StepHookData
from proto.metrics_config_pb2 import MetricConfig


class _Metric(ABC):
    """Base class for individual metric tracking."""

    def __init__(self, config: MetricConfig, logger: TensorboardLogger):
        self.config = config
        self.logger = logger

    def should_log(
        self, global_step: int, local_step: int, steps_per_epoch: int
    ) -> bool:
        """Check if it's time to log this metric."""
        if self.config.after_epoch and local_step + 1 == steps_per_epoch:
            return True
        # NOTE(review): a config with period == 0 would raise
        # ZeroDivisionError here — presumably configs always set a period.
        step = global_step if self.config.use_global_steps else local_step
        return (step + 1) % self.config.period == 0

    @abstractmethod
    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
        """Log the metric for the current step."""


================================================
FILE: src/lczero_training/daemon/pipeline.py
================================================
import dataclasses
import datetime
import gzip
import logging
import os
import threading
import time
from pathlib import Path
from typing import cast

import jax
import orbax.checkpoint as ocp
import requests
from dotenv import load_dotenv
from flax import nnx
from google.protobuf import text_format

from lczero_training._lczero_training import DataLoader
from lczero_training.convert.jax_to_leela import (
    LeelaExportOptions,
    jax_to_leela,
)
from lczero_training.model.loss_function import LczeroLoss
from lczero_training.model.model import LczeroModel
from lczero_training.training.lr_schedule import make_lr_schedule
from lczero_training.training.optimizer import make_gradient_transformation
from lczero_training.training.state import JitTrainingState, TrainingState
from lczero_training.training.training import (
    StepHookData,
    Training,
    from_dataloader,
)
from proto.data_loader_config_pb2 import DataLoaderConfig
from proto.root_config_pb2 import RootConfig
from proto.stage_control_pb2 import StageControlRequest, StageControlResponse
from proto.training_config_pb2 import ScheduleConfig

from .metrics import Metrics
from .protocol.messages import TrainingScheduleData, TrainingStage

logger = logging.getLogger(__name__)


def _read_config_file(config_filepath: str) -> RootConfig:
    """Parse a textproto RootConfig from disk."""
    config_path = Path(config_filepath)
    config_text = config_path.read_text()
    root_config = RootConfig()
    text_format.Parse(config_text, root_config)
    return root_config


def _make_dataloader(config: DataLoaderConfig) -> DataLoader:
    """Construct the native DataLoader from a serialized config proto."""
    config_bytes = config.SerializeToString()
    return DataLoader(config_bytes)


def _configure_file_logging(config: RootConfig) -> None:
    """Configure file logging if log_filename is specified in config."""
    if config.HasField("log_filename"):
        file_handler = logging.FileHandler(config.log_filename)
        file_handler.setFormatter(
            logging.Formatter(
                "%(levelname).1s%(asctime)s.%(msecs)03d %(name)s "
                "%(filename)s:%(lineno)d] %(message)s",
                datefmt="%m%d %H:%M:%S",
            )
        )
        logging.getLogger().addHandler(file_handler)
        logger.info(f"Added file logging to {config.log_filename}")


def _log_jax_system_info() -> None:
    """Log JAX system information including devices and backend details."""
    devices = jax.devices()
    local_devices = jax.local_devices()
    device_counts: dict[str, int] = {}
    for device in devices:
        device_type = device.device_kind
        device_counts[device_type] = device_counts.get(device_type, 0) + 1
    logger.info(f"JAX Backend: {jax.default_backend()}")
    logger.info(
        f"JAX Devices: {len(devices)} total, {len(local_devices)} local"
    )
    for device_type, count in device_counts.items():
        logger.info(f"  {device_type}: {count}")
    for i, device in enumerate(local_devices):
        logger.info(f"  Local device {i}: {device}")


@dataclasses.dataclass
class _TrainingCycleState:
    """Mutable bookkeeping for the current train/wait cycle (for status UI)."""

    start_time: float = dataclasses.field(default_factory=time.time)
    current_stage: TrainingStage = TrainingStage.WAITING_FOR_DATA
    completed_epochs: int = 0
    current_cycle_start_time: float = dataclasses.field(
        default_factory=time.time
    )
    current_training_start_time: float | None = None
    previous_training_duration: float = 0.0
    previous_cycle_duration: float = 0.0
    chunks_at_training_start: int = 0


class TrainingPipeline:
    """Owns the model, data loader, checkpointing and the train/export loop."""

    _data_loader: DataLoader
    _schedule: ScheduleConfig
    _chunks_to_wait: int
    _model: LczeroModel
    _checkpoint_mgr: ocp.CheckpointManager
    _training_state: TrainingState
    _cycle_state: _TrainingCycleState
    _metrics: Metrics | None

    def __init__(
        self,
        config_filepath: str,
        memory_profile_dir: str | None = None,
    ) -> None:
        self._memory_profile_dir = memory_profile_dir
        logger.info(f"Loading config from {config_filepath}")
        self._config = self._load_config(config_filepath)
        _configure_file_logging(self._config)
        self._schedule = self._config.training.schedule
        self._chunks_per_network = self._schedule.chunks_per_network
        self._num_steps_per_epoch = self._schedule.steps_per_network
        self._chunks_to_wait = self._chunks_per_network
        self._cycle_state = _TrainingCycleState()
        self._force_training_event = threading.Event()
        self._metrics = None
        logger.info("Creating empty model")
        # Fixed RNG seed: the real weights are restored from checkpoint below.
        self._model = LczeroModel(self._config.model, rngs=nnx.Rngs(params=42))
        self._graphdef = nnx.graphdef(self._model)
        logger.info(
            f"Creating checkpoint manager at {self._config.training.checkpoint.path}"
        )
        self._checkpoint_mgr = ocp.CheckpointManager(
            self._config.training.checkpoint.path,
            options=ocp.CheckpointManagerOptions(
                max_to_keep=self._config.training.checkpoint.max_to_keep
                or None,
            ),
        )
        logger.info("Restoring checkpoint")
        optimizer_config = self._config.training.optimizer
        max_grad_norm = getattr(self._config.training, "max_grad_norm", 0.0)
        self._lr_schedule = make_lr_schedule(self._config.training.lr_schedule)
        optimizer_tx = make_gradient_transformation(
            optimizer_config,
            max_grad_norm=max_grad_norm,
            lr_schedule=self._lr_schedule,
        )
        # Template state used as the restore target structure.
        model_state = nnx.state(self._model)
        jit_state = JitTrainingState(
            step=0,
            model_state=model_state,
            opt_state=optimizer_tx.init(model_state),
            swa_state=model_state,
            num_averages=0.0,
        )
        empty_state
= TrainingState( jit_state=jit_state, num_heads=self._config.model.encoder.heads, ) self._training_state = cast( TrainingState, self._checkpoint_mgr.restore( step=None, args=ocp.args.PyTreeRestore( item=empty_state, ), ), ) logger.info("Creating training session") loss_fn = LczeroLoss(config=self._config.training.losses) self._training = Training( optimizer_tx=make_gradient_transformation( self._config.training.optimizer, max_grad_norm=max_grad_norm, lr_schedule=self._lr_schedule, ), graphdef=nnx.graphdef(self._model), loss_fn=loss_fn, swa_config=( self._config.training.swa if self._config.training.HasField("swa") else None ), ) logger.info("Creating data loader") self._data_loader = _make_dataloader(self._config.data_loader) self._set_chunk_anchor(self._training_state.last_chunk_source) # Create metrics if configured. if self._config.HasField("metrics"): logger.info("Creating metrics") self._metrics = Metrics( config=self._config.metrics, loss_fn=loss_fn, data_loader=self._data_loader, ) else: logger.info("No metrics configured") _log_jax_system_info() def start_training_immediately(self) -> None: """Request the next training cycle to start without waiting for chunks.""" logger.info("Received request to start training immediately.") self._force_training_event.set() def run(self) -> None: logging.info("Starting DataLoader") self._data_loader.start() while True: self._wait_for_chunks() new_anchor, used_chunks = self._reset_chunk_anchor() logging.info(f"{new_anchor=} {used_chunks=}") self._training_state = self._training_state.replace( last_chunk_source=new_anchor ) self._chunks_to_wait = max( self._chunks_to_wait + self._chunks_per_network - used_chunks, self._chunks_per_network // 2, ) self._train_one_network() self._save_checkpoint() network_bytes = self._export_network() if network_bytes: self._save_network(network_bytes) self._upload_network(network_bytes) def _export_network(self) -> bytes | None: if ( not self._config.export.destination_filename and not 
self._config.export.HasField("upload_training_run") ): return None logging.info("Exporting network") options = LeelaExportOptions( min_version="0.31", num_heads=self._training_state.num_heads, license=None, training_steps=self._training_state.jit_state.step, ) export_state = ( self._training_state.jit_state.swa_state if self._config.export.export_swa_model else self._training_state.jit_state.model_state ) assert isinstance(export_state, nnx.State) net = jax_to_leela(jax_weights=export_state, export_options=options) return gzip.compress(net.SerializeToString()) def _save_network(self, network_bytes: bytes) -> None: date_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") step = self._training_state.jit_state.step for destination_template in self._config.export.destination_filename: destination = destination_template.format( datetime=date_str, step=step ) logging.info(f"Writing network to {destination}") os.makedirs(os.path.dirname(destination), exist_ok=True) with open(destination, "wb") as f: f.write(network_bytes) logging.info(f"Finished writing network to {destination}") def _upload_network(self, network_bytes: bytes) -> None: if not self._config.export.HasField("upload_training_run"): return load_dotenv() upload_pwd = os.getenv("UPLOAD_PWD") if not upload_pwd: logging.error( "UPLOAD_PWD not found in environment variables, skipping upload." 
    def _step_hook(self, hook_data: StepHookData) -> None:
        """Per-step callback passed into Training.run."""
        # Append current learning rate from schedule to metrics.
        hook_data.metrics["lr"] = self._lr_schedule(hook_data.global_step)
        if self._metrics is not None:
            self._metrics.on_step(hook_data, self._graphdef)

    def _train_one_network(self) -> None:
        """Run one training cycle and update cycle/stage bookkeeping."""
        logging.info("Training one network!")
        # Record training start
        self._cycle_state.current_training_start_time = time.time()
        self._cycle_state.current_stage = TrainingStage.TRAINING
        self._cycle_state.chunks_at_training_start = self._chunks_since_anchor()
        new_jit_state = self._training.run(
            jit_state=self._training_state.jit_state,
            datagen=from_dataloader(self._data_loader),
            num_steps=self._schedule.steps_per_network,
            step_hook=self._step_hook,
            memory_profile_dir=self._memory_profile_dir,
        )
        self._training_state = self._training_state.replace(
            jit_state=new_jit_state
        )
        # Record training end
        current_time = time.time()
        # start_time was set above, so this guard is effectively always true;
        # it protects against a None subtraction if state is mutated externally.
        if self._cycle_state.current_training_start_time:
            self._cycle_state.previous_training_duration = (
                current_time - self._cycle_state.current_training_start_time
            )
        self._cycle_state.previous_cycle_duration = (
            current_time - self._cycle_state.current_cycle_start_time
        )
        self._cycle_state.completed_epochs += 1
        self._cycle_state.current_training_start_time = None
        self._cycle_state.current_stage = TrainingStage.WAITING_FOR_DATA
        self._cycle_state.current_cycle_start_time = current_time
        logging.info("Done training")
    def _save_checkpoint(self) -> None:
        """Persist the full TrainingState via the Orbax checkpoint manager."""
        logging.info("Saving checkpoint")
        self._checkpoint_mgr.save(
            step=self._training_state.jit_state.step,
            args=ocp.args.PyTreeSave(item=self._training_state),
        )
        logging.info("Checkpoint saved")

    def stop(self) -> None:
        """Stop the data loader and close metrics, if configured."""
        self._data_loader.stop()
        if self._metrics is not None:
            self._metrics.close()

    def get_data_loader(self) -> DataLoader:
        """Return the daemon's DataLoader instance."""
        return self._data_loader

    def _wait_for_chunks(self) -> None:
        """Poll until enough new chunks arrived or a force-start is requested."""
        current_chunks = self._chunks_since_anchor()
        logger.info(
            f"Waiting for {self._chunks_to_wait} chunks. "
            f"got {current_chunks} so far"
        )
        while True:
            # Check the force-start flag first so an operator request wins
            # even when the chunk threshold has not been reached.
            if self._force_training_event.is_set():
                logger.info(
                    "Force start requested; skipping remaining chunk wait."
                )
                self._force_training_event.clear()
                # Treat the currently available chunks as "enough" so the
                # carry-over arithmetic in run() stays consistent.
                self._chunks_to_wait = self._chunks_since_anchor()
                return
            if self._chunks_since_anchor() >= self._chunks_to_wait:
                logger.info("Done waiting for enough chunks")
                return
            time.sleep(1)

    def get_training_schedule_data(
        self, daemon_start_time: float
    ) -> TrainingScheduleData:
        """Return current training schedule data for TUI display."""
        current_time = time.time()
        # Calculate current training time if currently training
        current_training_time = 0.0
        if self._cycle_state.current_training_start_time is not None:
            current_training_time = (
                current_time - self._cycle_state.current_training_start_time
            )
        # Calculate current cycle time
        current_cycle_time = (
            current_time - self._cycle_state.current_cycle_start_time
        )
        # Calculate new chunks since training start
        new_chunks_since_training_start = max(
            0,
            self._chunks_since_anchor()
            - self._cycle_state.chunks_at_training_start,
        )
        return TrainingScheduleData(
            current_stage=self._cycle_state.current_stage,
            completed_epochs_since_start=self._cycle_state.completed_epochs,
            new_chunks_since_training_start=new_chunks_since_training_start,
            chunks_to_wait=self._chunks_to_wait,
            total_uptime_seconds=current_time - daemon_start_time,
            current_training_time_seconds=current_training_time,
            previous_training_time_seconds=self._cycle_state.previous_training_duration,
            current_cycle_time_seconds=current_cycle_time,
            previous_cycle_time_seconds=self._cycle_state.previous_cycle_duration,
        )
completed_epochs_since_start=self._cycle_state.completed_epochs, new_chunks_since_training_start=new_chunks_since_training_start, chunks_to_wait=self._chunks_to_wait, total_uptime_seconds=current_time - daemon_start_time, current_training_time_seconds=current_training_time, previous_training_time_seconds=self._cycle_state.previous_training_duration, current_cycle_time_seconds=current_cycle_time, previous_cycle_time_seconds=self._cycle_state.previous_cycle_duration, ) def _send_chunk_pool_control( self, request: StageControlRequest ) -> StageControlResponse | None: responses = self._data_loader.send_control_message(request) for _, response in responses: if response.HasField("chunk_pool_response"): return response return None def _reset_chunk_anchor(self) -> tuple[str, int]: request = StageControlRequest() request.chunk_pool_request.reset_chunk_anchor = True response = self._send_chunk_pool_control(request) if not response or not response.HasField("chunk_pool_response"): return "", 0 chunk_response = response.chunk_pool_response return chunk_response.chunk_anchor, chunk_response.chunks_since_anchor def _chunks_since_anchor(self) -> int: request = StageControlRequest() request.chunk_pool_request.SetInParent() response = self._send_chunk_pool_control(request) if not response or not response.HasField("chunk_pool_response"): return 0 return response.chunk_pool_response.chunks_since_anchor def _set_chunk_anchor(self, anchor: str) -> None: request = StageControlRequest() request.chunk_pool_request.set_chunk_anchor = anchor or "" self._send_chunk_pool_control(request) def _load_config(self, config_filepath: str) -> RootConfig: config_path = Path(config_filepath) config_text = config_path.read_text() root_config = RootConfig() text_format.Parse(config_text, root_config) return root_config ================================================ FILE: src/lczero_training/daemon/protocol/__init__.py ================================================ # ABOUTME: Protocol package for JSONL 
def _to_serializable(obj: Any) -> Any:
    """Convert dataclass/protobuf/enum objects to JSON-serializable values."""
    if isinstance(obj, Message):
        return MessageToDict(
            obj, preserving_proto_field_name=True, use_integers_for_enums=True
        )
    elif isinstance(obj, Enum):
        return obj.value
    elif is_dataclass(obj):
        # None-valued fields are omitted so optional payload fields stay
        # absent from the wire format (mirrored by _from_serializable).
        return {
            f.name: _to_serializable(getattr(obj, f.name))
            for f in obj.__dataclass_fields__.values()
            if getattr(obj, f.name) is not None
        }
    elif isinstance(obj, (list, tuple)):
        return [_to_serializable(item) for item in obj]
    elif isinstance(obj, dict):
        return {k: _to_serializable(v) for k, v in obj.items()}
    else:
        return obj


def _unwrap_optional(t: Any) -> Any:
    """Extract T from T | None or Union[T, None]."""
    if isinstance(t, types.UnionType) or get_origin(t) is Union:
        args = [a for a in get_args(t) if a is not type(None)]
        return args[0] if len(args) == 1 else t
    return t


def _is_protobuf(cls: type) -> bool:
    """Check if cls is a protobuf Message class."""
    try:
        return isinstance(cls, type) and issubclass(cls, Message)
    except TypeError:
        return False


def _is_enum(cls: Any) -> bool:
    """Check if cls is an Enum subclass."""
    return isinstance(cls, type) and issubclass(cls, Enum)


def _from_serializable(cls: type, data: Any) -> Any:
    """Reconstruct dataclass/protobuf/enum values from their JSON form.

    Mirrors _to_serializable: fields absent from *data* keep their dataclass
    defaults.
    """
    if _is_protobuf(cls):
        instance = cls()
        ParseDict(data, instance)
        return instance
    if not is_dataclass(cls):
        return data
    args = {}
    for field in cls.__dataclass_fields__.values():
        if field.name not in data:
            continue
        value = data[field.name]
        field_type = _unwrap_optional(field.type)
        if get_origin(field_type) is list:
            item_type = get_args(field_type)[0]
            if is_dataclass(item_type) or _is_protobuf(item_type):
                value = [_from_serializable(item_type, item) for item in value]
            elif _is_enum(item_type):
                # Fix: enum items inside list fields were previously left as
                # raw wire values; convert them back like scalar enum fields.
                value = [item_type(item) for item in value]
        elif is_dataclass(field_type) or _is_protobuf(field_type):
            value = _from_serializable(field_type, value)
        elif _is_enum(field_type):
            # Convert string value back to enum
            value = field_type(value)
        args[field.name] = value
    return cls(**args)
class Communicator:
    """Blocking JSONL notification endpoint over a pair of text streams.

    Incoming lines are decoded and routed to ``on_<event>`` methods on the
    handler; outgoing payloads are looked up in the registry and written as
    single JSON lines.
    """

    def __init__(
        self, handler: Any, input_stream: TextIO, output_stream: TextIO
    ) -> None:
        """Store the handler and the streams used for IPC.

        Args:
            handler: An object with `on_` methods.
            input_stream: File-like object to read incoming messages from
                (e.g., sys.stdin).
            output_stream: File-like object to write outgoing messages to
                (e.g., sys.stdout).
        """
        self.handler = handler
        self.input = input_stream
        self.output = output_stream

    def send(self, payload_instance: Any) -> None:
        """Serialize *payload_instance* and emit it as one notification line.

        The event type string is resolved from the payload registry; raises
        TypeError for unregistered payload classes.
        """
        payload_cls = type(payload_instance)
        event_type = CLASS_TO_TYPE_MAP.get(payload_cls)
        if event_type is None:
            raise TypeError(
                f"Object of type {payload_cls.__name__} is not a registered payload."
            )
        envelope = {
            "type": event_type,
            "payload": _to_serializable(payload_instance),
        }
        self.output.write(json.dumps(envelope) + "\n")
        self.output.flush()

    def _dispatch(self, line: str) -> None:
        """Decode one JSON line and invoke the matching handler method."""
        stripped = line.strip()
        if not stripped:
            return
        envelope = json.loads(stripped)
        event_type = envelope["type"]
        payload_instance = _from_serializable(
            TYPE_TO_CLASS_MAP[event_type], envelope["payload"]
        )
        getattr(self.handler, f"on_{event_type}")(payload_instance)

    def run(self) -> None:
        """Block reading the input stream line-by-line, dispatching each line.

        Returns when the input stream is closed.
        """
        for incoming in self.input:
            self._dispatch(incoming)
class AsyncCommunicator:
    def __init__(
        self,
        handler: Any,
        input_stream: TextReceiveStream,
        output_stream: TextSendStream,
        io_dump: Optional[TextIO] = None,
    ) -> None:
        """
        Initializes the AsyncCommunicator.

        Args:
            handler: An object with async `on_` methods.
            input_stream: A TextReceiveStream to read incoming messages from.
            output_stream: A TextSendStream to write outgoing messages to.
            io_dump: Optional file to dump raw IO for debugging.
        """
        self.handler = handler
        self.input_stream = input_stream
        self.output_stream = output_stream
        self._io_dump = io_dump
        # Carries a partial line between received stream chunks.
        self._buffer = ""

    async def send(self, payload_instance: Any) -> None:
        """
        Serializes and sends a payload object as a notification.
        The event type is automatically looked up from the registry.

        Raises:
            TypeError: if the payload class is not registered.
        """
        payload_cls = type(payload_instance)
        event_type = CLASS_TO_TYPE_MAP.get(payload_cls)
        if event_type is None:
            raise TypeError(
                f"Object of type {payload_cls.__name__} is not a registered payload."
            )
        payload_dict = _to_serializable(payload_instance)
        message = {"type": event_type, "payload": payload_dict}
        message_line = json.dumps(message) + "\n"
        if self._io_dump:
            # Outgoing traffic is prefixed with "> " in the dump.
            self._io_dump.write(f"> {message_line}")
        await self.output_stream.send(message_line)

    async def _dispatch(self, line: str) -> None:
        # Deserialize one JSON line and await the matching `on_<type>` handler.
        line = line.strip()
        if not line:
            return
        data = json.loads(line)
        event_type = data["type"]
        payload_dict = data["payload"]
        payload_cls = TYPE_TO_CLASS_MAP[event_type]
        payload_instance = _from_serializable(payload_cls, payload_dict)
        handler_method_name = f"on_{event_type}"
        handler_method = getattr(self.handler, handler_method_name)
        await handler_method(payload_instance)

    async def run(self) -> None:
        """
        Starts the async listener loop.

        Reads from the input stream line-by-line, deserializes notifications,
        and dispatches them to the appropriate async handler method.
        This method runs until the input stream is closed.
        """
        async for chunk in self.input_stream:
            # Chunks are not line-aligned: accumulate, then split on newlines.
            self._buffer += chunk
            while "\n" in self._buffer:
                line, self._buffer = self._buffer.split("\n", 1)
                if self._io_dump and line:
                    self._io_dump.write(f"< {line}\n")
                await self._dispatch(line)
        # Stream closed: dispatch any trailing line lacking a final newline.
        if self._buffer:
            await self._dispatch(self._buffer)
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import proto.training_metrics_pb2 as training_metrics_pb2

from .registry import register


class TrainingStage(Enum):
    """High-level stage of the training daemon's cycle (TUI display values)."""

    WAITING_FOR_DATA = "WAITING FOR DATA"
    TRAINING = "TRAINING"


@dataclass
class TrainingScheduleData:
    """Snapshot of cycle timing/progress counters for TUI display."""

    current_stage: TrainingStage
    completed_epochs_since_start: int
    new_chunks_since_training_start: int
    chunks_to_wait: int
    total_uptime_seconds: float
    current_training_time_seconds: float
    previous_training_time_seconds: float
    current_cycle_time_seconds: float
    previous_cycle_time_seconds: float


# --- Notifications from UI (Parent) to Trainer (Child) ---
@register("start_training")
@dataclass
class StartTrainingPayload:
    """Ask the trainer to start with the given textproto config file."""

    config_filepath: str


@register("start_training_immediately")
@dataclass
class StartTrainingImmediatelyPayload:
    """Ask the trainer to begin the next cycle without waiting for chunks."""

    pass


# --- Notifications from Trainer (Child) to UI (Parent) ---
@register("training_status")
@dataclass
class TrainingStatusPayload:
    """Periodic status update; all fields optional so partial updates work."""

    dataloader_update_secs: Optional[float] = None
    dataloader_1_second: Optional[
        training_metrics_pb2.DataLoaderMetricsProto
    ] = None
    dataloader_total: Optional[training_metrics_pb2.DataLoaderMetricsProto] = (
        None
    )
    training_schedule: Optional[TrainingScheduleData] = None

================================================
FILE: src/lczero_training/daemon/protocol/registry.py
================================================
# ABOUTME: Registry system for mapping event type strings to payload dataclasses.
# ABOUTME: Provides @register decorator and maintains bidirectional mapping dicts.

import inspect
from typing import Callable

# These maps will be populated by the @register decorator
TYPE_TO_CLASS_MAP: dict[str, type] = {}
CLASS_TO_TYPE_MAP: dict[type, str] = {}


def register(event_type: str) -> Callable[[type], type]:
    """A decorator to register a payload dataclass with its event type string."""

    def decorator(cls: type) -> type:
        if not inspect.isclass(cls):
            raise TypeError(
                "The @register decorator can only be used on classes."
            )
        # Enforce a bijection: each event type and each class registers once.
        if event_type in TYPE_TO_CLASS_MAP:
            raise ValueError(
                f"Event type '{event_type}' is already registered."
            )
        if cls in CLASS_TO_TYPE_MAP:
            raise ValueError(f"Class '{cls.__name__}' is already registered.")
        TYPE_TO_CLASS_MAP[event_type] = cls
        CLASS_TO_TYPE_MAP[cls] = event_type
        return cls

    return decorator
@jax.jit
def compute_rms(state_subtree: nnx.State) -> jax.Array:
    """Compute RMS of all parameters in a state subtree."""
    leaves = jax.tree_util.tree_leaves(state_subtree)
    # Pooled RMS: sqrt(sum of squares / total element count) over all leaves.
    total_sq = sum(jnp.sum(jnp.square(p)) for p in leaves)
    total_n = sum(p.size for p in leaves)
    return jnp.sqrt(total_sq / total_n)


def extract_attention_components(model: LczeroModel) -> dict[str, Any]:
    """Extract Q, K, V, output_dense, smolgen from all encoder layers.

    Args:
        model: LczeroModel instance.

    Returns:
        Dict with keys 'q', 'k', 'v', 'output_dense', optionally 'smolgen';
        each maps 'layer_<i>' to that layer's parameter state.
    """
    components: dict[str, Any] = {
        "q": {},
        "k": {},
        "v": {},
        "output_dense": {},
    }
    encoder_layers = cast(list[EncoderBlock], model.encoders.encoders.layers)
    for i, encoder_block in enumerate(encoder_layers):
        mha = encoder_block.mha
        components["q"][f"layer_{i}"] = nnx.state(mha.q)
        components["k"][f"layer_{i}"] = nnx.state(mha.k)
        components["v"][f"layer_{i}"] = nnx.state(mha.v)
        components["output_dense"][f"layer_{i}"] = nnx.state(mha.output_dense)
        # The 'smolgen' key only exists when at least one layer has smolgen.
        if mha.smolgen is not None:
            if "smolgen" not in components:
                components["smolgen"] = {}
            components["smolgen"][f"layer_{i}"] = nnx.state(mha.smolgen)
    return components
def collect_rms_metrics(model: LczeroModel) -> dict[str, Any]:
    """Collect all RMS metrics for the model.

    Args:
        model: LczeroModel instance.

    Returns:
        Nested dict with RMS values for different model components.
    """
    model_state = nnx.state(model)
    metrics: dict[str, Any] = {
        "all_params": compute_rms(model_state),
        "embedding": compute_rms(nnx.state(model.embedding)),
        "encoder_body": compute_rms(nnx.state(model.encoders)),
    }
    # Attention components
    attn_components = extract_attention_components(model)
    metrics["attention"] = {
        name: compute_rms(component)
        for name, component in attn_components.items()
    }
    # Policy heads
    metrics["policy_heads"] = {
        name: compute_rms(nnx.state(head))
        for name, head in model.policy_heads.items()
    }
    # Value heads
    metrics["value_heads"] = {
        name: compute_rms(nnx.state(head))
        for name, head in model.value_heads.items()
    }
    # Movesleft heads
    metrics["movesleft_heads"] = {
        name: compute_rms(nnx.state(head))
        for name, head in model.movesleft_heads.items()
    }
    return metrics


class _RmsMetric(_Metric):
    """Metric that computes RMS of model parameters."""

    # The previous __init__ only forwarded (config, logger) to the base class
    # unchanged, so it was a useless override and has been removed;
    # _Metric.__init__ is used directly.

    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
        """Rebuild the model from graphdef + state and log its RMS metrics."""
        # Use SWA weights when configured, otherwise the raw model weights.
        model_state = (
            hook_data.jit_state.swa_state
            if self.config.use_swa_model
            else hook_data.jit_state.model_state
        )
        model = nnx.merge(graphdef, model_state)
        metrics = collect_rms_metrics(model)
        self.logger.log(hook_data.global_step, metrics)
================================================
FILE: src/lczero_training/dataloader/__init__.py
================================================
from lczero_training._lczero_training import (
    DataLoader,
    TensorBase,
)
from proto.data_loader_config_pb2 import DataLoaderConfig

__all__ = ["DataLoader", "make_dataloader", "TensorBase"]


def make_dataloader(config: DataLoaderConfig) -> DataLoader:
    """Construct a native DataLoader from config and start it immediately."""
    loader = DataLoader(config)
    loader.start()
    return loader

================================================
FILE: src/lczero_training/model/__init__.py
================================================

================================================
FILE: src/lczero_training/model/embedding.py
================================================
import jax
import jax.numpy as jnp
from flax import nnx

from proto import model_config_pb2

from .shared import Ffn
from .utils import get_activation


class Embedding(nnx.Module):
    """Computes embeddings for the input features."""

    def __init__(
        self,
        *,
        input_channels: int,
        config: model_config_pb2.EmbeddingConfig,
        defaults: model_config_pb2.DefaultsConfig,
        deepnorm_alpha: float,
        deepnorm_beta: float,
        rngs: nnx.Rngs,
    ):
        self._input_channels = input_channels
        dense_size = config.dense_size
        embedding_size = config.embedding_size
        self.activation = defaults.activation
        assert dense_size > 0
        # Mixes the first 12 input channels across all 64 squares into
        # per-square positional features.
        self.preprocess = nnx.Linear(
            in_features=64 * 12,
            out_features=64 * dense_size,
            rngs=rngs,
        )
        assert embedding_size > 0
        self.embedding = nnx.Linear(
            in_features=input_channels + dense_size,
            out_features=embedding_size,
            rngs=rngs,
        )
        self.norm = nnx.LayerNorm(embedding_size, epsilon=1e-3, rngs=rngs)
        self.ma_gating = MaGating(feature_shape=(64, embedding_size), rngs=rngs)
        self.deepnorm_alpha = deepnorm_alpha
        self.ffn = Ffn(
            in_features=embedding_size,
            hidden_features=config.dff,
            hidden_activation=defaults.ffn_activation,
            deepnorm_beta=deepnorm_beta,
            rngs=rngs,
        )
        self.out_norm = nnx.LayerNorm(embedding_size, epsilon=1e-3, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # NOTE(review): the flatten()/reshape((64, -1)) pair implies x is a
        # single unbatched position of shape (64, input_channels) -- confirm
        # that callers vmap over the batch dimension.
        # Preprocess positional info and concatenate to input.
        pos_info = self.preprocess(x[..., :12].flatten()).reshape((64, -1))
        x = jnp.concatenate([x, pos_info], axis=1)
        # Square embedding.
        x = self.embedding(x)
        x = get_activation(self.activation)(x)
        x = self.norm(x)
        x = self.ma_gating(x)
        # FFN block with residual connection and layer norm.
        x = x + self.ffn(x) * self.deepnorm_alpha
        x = self.out_norm(x)
        return x
class MaGating(nnx.Module):
    """Applies multiplicative and additive gating, in that order."""

    def __init__(self, feature_shape: tuple[int, ...], *, rngs: nnx.Rngs):
        # Multiplicative gate first, additive gate second.
        self.mult_gate = Gating(
            feature_shape=feature_shape, additive=False, rngs=rngs
        )
        self.add_gate = Gating(
            feature_shape=feature_shape, additive=True, rngs=rngs
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        scaled = self.mult_gate(x)
        return self.add_gate(scaled)


class Gating(nnx.Module):
    """Learned per-feature gate: additive (bias) or multiplicative (scale)."""

    def __init__(
        self,
        feature_shape: tuple[int, ...],
        additive: bool = True,
        *,
        rngs: nnx.Rngs,
    ):
        self.additive = additive
        # Identity initialization: add zero, or multiply by one.
        initial = 0.0 if additive else 1.0
        self.gate = nnx.Param(
            jnp.full(feature_shape, initial, dtype=jnp.float32)
        )

    def __call__(self, inputs: jax.Array) -> jax.Array:
        if self.additive:
            return inputs + self.gate.value
        # Multiplicative gates are clamped to be non-negative via relu.
        return inputs * jax.nn.relu(self.gate.value)
class EncoderTower(nnx.Module):
    """Stack of transformer EncoderBlocks applied sequentially."""

    def __init__(
        self,
        *,
        in_features: int,
        config: model_config_pb2.EncoderConfig,
        defaults: model_config_pb2.DefaultsConfig,
        deepnorm_beta: float,
        rngs: nnx.Rngs,
    ):
        smolgen_shared_gen_dense = None
        # NOTE(review): the assert makes smolgen mandatory, so the `if` below
        # can only be False under `python -O` -- confirm whether smolgen-less
        # configs are meant to be supported.
        assert config.HasField("smolgen")
        if config.HasField("smolgen"):
            # One weight-generation layer shared by every block's Smolgen.
            smolgen_shared_gen_dense = nnx.Linear(
                in_features=config.smolgen.gen_size,
                out_features=64 * 64,
                use_bias=False,
                rngs=rngs,
            )
        self.encoders = nnx.Sequential(
            *[
                EncoderBlock(
                    in_features=in_features,
                    config=config,
                    defaults=defaults,
                    smol_gen_dense=smolgen_shared_gen_dense,
                    deepnorm_beta=deepnorm_beta,
                    rngs=rngs,
                )
                for _ in range(config.num_blocks)
            ]
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.encoders(x)


class EncoderBlock(nnx.Module):
    """A single block of the transformer encoder."""

    def __init__(
        self,
        *,
        in_features: int,
        config: model_config_pb2.EncoderConfig,
        defaults: model_config_pb2.DefaultsConfig,
        smol_gen_dense: Optional[nnx.Linear],
        deepnorm_beta: float,
        rngs: nnx.Rngs,
    ):
        # The shared smolgen dense must be provided iff smolgen is configured.
        assert (smol_gen_dense is not None) == config.HasField("smolgen")
        self.mha = MultiHeadAttention(
            in_features=in_features,
            config=config,
            defaults=defaults,
            smol_gen_dense=smol_gen_dense,
            deepnorm_beta=deepnorm_beta,
            rngs=rngs,
        )
        # Residual scaling factor alpha = (2 * num_blocks) ** -0.25
        # (paired with the deepnorm_beta-scaled initializers).
        self.alpha = math.pow(2.0 * config.num_blocks, -0.25)
        self.ln1 = nnx.LayerNorm(in_features, epsilon=1e-3, rngs=rngs)
        self.ffn = Ffn(
            in_features=in_features,
            hidden_features=config.dff,
            hidden_activation=defaults.ffn_activation,
            deepnorm_beta=deepnorm_beta,
            rngs=rngs,
        )
        self.ln2 = nnx.LayerNorm(in_features, epsilon=1e-3, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # Attention sub-block with scaled residual, then FFN sub-block.
        x = x + self.mha(x) * self.alpha
        out1 = self.ln1(x)
        ffn_out = self.ffn(out1)
        return self.ln2(out1 + ffn_out * self.alpha)
class MultiHeadAttention(nnx.Module):
    """Multi-head attention module."""

    def __init__(
        self,
        in_features: int,
        config: model_config_pb2.EncoderConfig,
        defaults: model_config_pb2.DefaultsConfig,
        smol_gen_dense: Optional[nnx.Linear],
        deepnorm_beta: float,
        *,
        rngs: nnx.Rngs,
    ):
        depth = config.d_model
        assert depth % config.heads == 0, (
            "Model depth must be divisible by the number of heads."
        )
        self.activation = defaults.activation
        self.depth = depth
        self.num_heads = config.heads
        self.q = nnx.Linear(
            in_features=in_features, out_features=depth, rngs=rngs
        )
        self.k = nnx.Linear(
            in_features=in_features, out_features=depth, rngs=rngs
        )
        # DeepNorm-style scaled init, applied to V and the output projection.
        deepnorm_init = flax_initializers.variance_scaling(
            scale=deepnorm_beta,
            mode="fan_avg",
            distribution="truncated_normal",
        )
        self.v = nnx.Linear(
            in_features=in_features,
            out_features=depth,
            kernel_init=deepnorm_init,
            rngs=rngs,
        )
        self.output_dense = nnx.Linear(
            in_features=depth,
            out_features=in_features,
            kernel_init=deepnorm_init,
            rngs=rngs,
        )
        assert (smol_gen_dense is not None) == config.HasField("smolgen")
        self.smolgen: Optional[Smolgen]
        if smol_gen_dense is not None:
            self.smolgen = Smolgen(
                in_features=in_features,
                config=config.smolgen,
                defaults=defaults,
                heads=config.heads,
                weight_gen_dense=smol_gen_dense,
                rngs=rngs,
            )
        else:
            self.smolgen = None

    def __call__(self, x: jax.Array) -> jax.Array:
        q, k, v = self.q(x), self.k(x), self.v(x)
        head_depth = self.depth // self.num_heads
        # Reshape for multi-head attention: (seq, depth) -> (heads, seq, head_depth).
        q, k, v = (
            t.reshape((-1, self.num_heads, head_depth)).transpose((1, 0, 2))
            for t in (q, k, v)
        )
        # Scaled dot-product attention.
        logits = jnp.einsum("...qd,...kd->...qk", q, k)
        logits /= jnp.sqrt(k.shape[-1]).astype(k.dtype)
        # Smolgen contributes a per-head additive attention bias.
        if self.smolgen is not None:
            logits += self.smolgen(x)
        attention_weights = nnx.softmax(logits, axis=-1)
        scaled_attention = jnp.matmul(attention_weights, v)
        # Reshape back to original dimensions.
        scaled_attention = scaled_attention.transpose((1, 0, 2)).reshape(
            (-1, self.depth)
        )
        return self.output_dense(scaled_attention)
class Smolgen(nnx.Module):
    """Smolgen module for generating attention biases."""

    def __init__(
        self,
        in_features: int,
        config: model_config_pb2.SmolgenConfig,
        defaults: model_config_pb2.DefaultsConfig,
        heads: int,
        weight_gen_dense: nnx.Linear,
        *,
        rngs: nnx.Rngs,
    ):
        self.heads = heads
        # Per-square compression, then two dense layers with layer norms.
        self.compress = nnx.Linear(
            in_features=in_features,
            out_features=config.hidden_channels,
            use_bias=False,
            rngs=rngs,
        )
        self.dense1 = nnx.Linear(
            in_features=config.hidden_channels * 64,
            out_features=config.hidden_size,
            rngs=rngs,
        )
        self.ln1 = nnx.LayerNorm(config.hidden_size, epsilon=1e-3, rngs=rngs)
        self.dense2 = nnx.Linear(
            in_features=config.hidden_size,
            out_features=config.gen_size * heads,
            rngs=rngs,
        )
        self.ln2 = nnx.LayerNorm(
            config.gen_size * heads, epsilon=1e-3, rngs=rngs
        )
        # Shared across all encoder blocks (created by EncoderTower).
        self.weight_gen_dense = weight_gen_dense
        # Per-module activation override falls back to the defaults config.
        self.activation = config.activation or defaults.activation

    def __call__(self, x: jax.Array) -> jax.Array:
        act = get_activation(self.activation)
        squeezed = self.compress(x).flatten()
        hidden = self.ln1(act(self.dense1(squeezed)))
        generated = self.ln2(act(self.dense2(hidden)))
        per_head = generated.reshape((self.heads, -1))
        weights = self.weight_gen_dense(per_head)
        # One 64x64 attention-bias matrix per head.
        return weights.reshape((self.heads, 64, 64))
def _compute_q_from_wdl(wdl_logits: jax.Array) -> jax.Array:
    """Compute Q value from WDL logits."""
    # Q = P(win) - P(loss), obtained by weighting softmax(WDL) by [1, 0, -1].
    wdl_probs = jax.nn.softmax(wdl_logits)
    q_weights = jnp.array([1.0, 0.0, -1.0])
    return jnp.dot(wdl_probs, q_weights)


class LossBase:
    """Common base for per-head losses: head/metric names plus weight."""

    def __init__(
        self,
        config: Union[
            PolicyLossConfig,
            ValueLossConfig,
            MovesLeftLossConfig,
            ValueErrorLossConfig,
            ValueCategoricalLossConfig,
        ],
    ) -> None:
        self.head_name = config.head_name
        # Metric name falls back to the head name when not set.
        self.metric_name = config.metric_name or config.head_name
        self.weight = config.weight

    def __call__(
        self,
        predictions: ModelPrediction,
        sample: TrainingSample,
    ) -> jax.Array:
        raise NotImplementedError("Subclasses must implement __call__")


class RegularizationLoss:
    """Computes regularization loss on model parameters."""

    def __init__(self, config: RegularizationLossConfig) -> None:
        self.metric_name = config.metric_name or "l2"
        self.weight = config.weight
        self._selector = config.selector

    def __call__(self, model: LczeroModel) -> jax.Array:
        """Return the sum of squares of the selector-matched parameters."""
        params = nnx.state(model, nnx.Param)
        mask = make_weights_mask(self._selector, params)
        # Zero out unselected parameters so the sum only covers the mask.
        masked_params = jax.tree.map(
            lambda p, m: p.value if m else jnp.zeros_like(p.value),
            params,
            mask,
            is_leaf=lambda x: isinstance(x, nnx.Variable),
        )
        leaves = jax.tree.leaves(masked_params)
        return sum(
            (jnp.sum(jnp.square(leaf)) for leaf in leaves), jnp.array(0.0)
        )
class LczeroLoss:
    """Aggregates all configured loss terms into one weighted total."""

    policy_losses: List["PolicyLoss"]
    value_losses: List["ValueLoss"]
    movesleft_losses: List["MovesLeftLoss"]
    value_error_losses: List["ValueErrorLoss"]
    value_categorical_losses: List["ValueCategoricalLoss"]
    regularization_losses: List["RegularizationLoss"]

    def __init__(self, config: LossConfig) -> None:
        self.config = config
        self.policy_losses = [
            PolicyLoss(loss_config) for loss_config in config.policy
        ]
        self.value_losses = [
            ValueLoss(loss_config) for loss_config in config.value
        ]
        self.movesleft_losses = [
            MovesLeftLoss(loss_config) for loss_config in config.movesleft
        ]
        self.value_error_losses = [
            ValueErrorLoss(loss_config) for loss_config in config.value_error
        ]
        self.value_categorical_losses = [
            ValueCategoricalLoss(loss_config)
            for loss_config in config.value_categorical
        ]
        self.regularization_losses = [
            RegularizationLoss(loss_config)
            for loss_config in config.regularization
        ]

        # Metric names must be unique within each loss category, since they
        # become keys of the returned metrics dict.
        def _validate_no_duplicate_metrics(
            loss_type_name: str,
            losses: Sequence[Union[LossBase, RegularizationLoss]],
        ) -> None:
            seen = set()
            for name in (loss.metric_name for loss in losses):
                if name in seen:
                    raise ValueError(
                        f"Duplicate metric name: {loss_type_name}/{name}"
                    )
                seen.add(name)

        _validate_no_duplicate_metrics("policy", self.policy_losses)
        _validate_no_duplicate_metrics("value", self.value_losses)
        _validate_no_duplicate_metrics("movesleft", self.movesleft_losses)
        _validate_no_duplicate_metrics("value_error", self.value_error_losses)
        _validate_no_duplicate_metrics(
            "value_categorical", self.value_categorical_losses
        )
        _validate_no_duplicate_metrics(
            "regularization", self.regularization_losses
        )

    def __call__(
        self,
        model: LczeroModel,
        sample: TrainingSample,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Return (total weighted loss, per-metric unweighted losses)."""
        # Run model forward pass.
        predictions = model(sample.inputs)
        unweighted_losses: Dict[str, jax.Array] = {}
        weighted_losses: List[jax.Array] = []
        for policy_loss in self.policy_losses:
            loss = policy_loss(predictions, sample)
            unweighted_losses[f"policy/{policy_loss.metric_name}"] = loss
            weighted_losses.append(loss * policy_loss.weight)
        for value_loss in self.value_losses:
            loss = value_loss(predictions, sample)
            unweighted_losses[f"value/{value_loss.metric_name}"] = loss
            weighted_losses.append(loss * value_loss.weight)
        for movesleft_loss in self.movesleft_losses:
            loss = movesleft_loss(predictions, sample)
            unweighted_losses[f"movesleft/{movesleft_loss.metric_name}"] = loss
            weighted_losses.append(loss * movesleft_loss.weight)
        for value_error_loss in self.value_error_losses:
            loss = value_error_loss(predictions, sample)
            unweighted_losses[f"value_error/{value_error_loss.metric_name}"] = (
                loss
            )
            weighted_losses.append(loss * value_error_loss.weight)
        for value_categorical_loss in self.value_categorical_losses:
            loss = value_categorical_loss(predictions, sample)
            unweighted_losses[
                f"value_categorical/{value_categorical_loss.metric_name}"
            ] = loss
            weighted_losses.append(loss * value_categorical_loss.weight)
        # Regularization terms are computed from the model, not predictions.
        for reg_loss in self.regularization_losses:
            loss = reg_loss(model)
            unweighted_losses[f"regularization/{reg_loss.metric_name}"] = loss
            weighted_losses.append(loss * reg_loss.weight)
        data_loss = jnp.sum(jnp.array(weighted_losses))
        return data_loss, unweighted_losses
class ValueLoss(LossBase):
    """Cross-entropy between predicted WDL logits and the target WDL."""

    def __init__(self, config: ValueLossConfig) -> None:
        super().__init__(config)
        self.value_type = config.value_type

    def __call__(
        self,
        predictions: ModelPrediction,
        sample: TrainingSample,
    ) -> jax.Array:
        head_pred = predictions.value[self.head_name]
        logits = head_pred[0]
        # Targets arrive as raw (q, d); translate them into a W/D/L
        # distribution: w = (1 + q - d) / 2, l = (1 - q - d) / 2.
        target_q = sample.values[self.value_type, 0]
        target_d = sample.values[self.value_type, 1]
        p_win = (1.0 + target_q - target_d) / 2.0
        p_loss = (1.0 - target_q - target_d) / 2.0
        target_wdl = jnp.stack([p_win, target_d, p_loss], axis=-1)
        # Cross-entropy against the (gradient-blocked) target distribution.
        ce = optax.softmax_cross_entropy(
            logits=logits, labels=jax.lax.stop_gradient(target_wdl)
        )
        assert isinstance(ce, jax.Array)
        return ce
value_q = sample.values[self.value_type, 0] value_d = sample.values[self.value_type, 1] # Compute WDL: w = (1 + q - d) / 2, l = (1 - q - d) / 2 value_w = (1.0 + value_q - value_d) / 2.0 value_l = (1.0 - value_q - value_d) / 2.0 value_wdl = jnp.stack([value_w, value_d, value_l], axis=-1) # The cross-entropy between the predicted value and the target value. value_cross_entropy = optax.softmax_cross_entropy( logits=value_logits, labels=jax.lax.stop_gradient(value_wdl) ) assert isinstance(value_cross_entropy, jax.Array) return value_cross_entropy class PolicyLoss(LossBase): def __init__(self, config: PolicyLossConfig): super().__init__(config) self.config = config if config.type == PolicyLossConfig.LOSS_TYPE_UNSPECIFIED: raise ValueError( f"Policy loss type must be specified for head '{config.head_name}'." ) self._loss_type = config.type temperature = config.temperature if temperature <= 0: temperature = 1.0 self._temperature = temperature # Store optimistic config if present. if config.HasField("optimistic"): opt = config.optimistic self.opt_value_head: Optional[str] = opt.value_head_name self.opt_value_type = opt.value_type self.opt_strength = opt.strength self.opt_eps = opt.eps self.opt_alpha = opt.alpha self.opt_propagate_gradients = opt.propagate_value_gradients else: self.opt_value_head = None def _apply_temperature_and_normalize( self, policy_targets: jax.Array ) -> jax.Array: if self._temperature == 1.0: return policy_targets # Apply temperature scaling. policy_targets = jnp.power(policy_targets, 1.0 / self._temperature) # Renormalize after temperature scaling. 
target_sum = jnp.sum(policy_targets, axis=-1, keepdims=True) safe_sum = jnp.where( target_sum > 0, target_sum, jnp.ones_like(target_sum) ) return policy_targets / safe_sum def _compute_optimistic_weight( self, value_pred: Tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]], target_q: jax.Array, ) -> jax.Array: """Compute optimistic policy weight from value head predictions.""" wdl_logits = value_pred[0] error_pred = value_pred[1] assert error_pred is not None, ( "Error prediction required for optimistic weighting" ) # Optionally block gradients to value and error heads. if not self.opt_propagate_gradients: wdl_logits = jax.lax.stop_gradient(wdl_logits) error_pred = jax.lax.stop_gradient(error_pred) # Compute predicted Q from WDL. q_pred = _compute_q_from_wdl(wdl_logits) # Compute sigma and z-score. sigma = jnp.sqrt(error_pred.squeeze()) z = (target_q - q_pred) / (sigma + self.opt_eps) # Compute weight. return jax.nn.sigmoid((z - self.opt_strength) * self.opt_alpha) def __call__( self, predictions: ModelPrediction, sample: TrainingSample, ) -> jax.Array: policy_pred = predictions.policy[self.head_name] # Extract probabilities from sample. policy_targets = jnp.asarray( sample.probabilities, dtype=policy_pred.dtype ) if self.config.illegal_moves == PolicyLossConfig.MASK: policy_pred = jnp.where(policy_targets >= 0, policy_pred, -jnp.inf) # Zero out negative targets for illegal moves. policy_targets = jax.nn.relu(policy_targets) # Apply temperature scaling and renormalization if needed. policy_targets = self._apply_temperature_and_normalize(policy_targets) cross_entropy = cast( jax.Array, optax.safe_softmax_cross_entropy( logits=policy_pred, labels=policy_targets ), ) if self._loss_type == PolicyLossConfig.CROSS_ENTROPY: loss = cross_entropy elif self._loss_type == PolicyLossConfig.KL: loss = cross_entropy + jnp.sum( xlogy(policy_targets, policy_targets), axis=-1 ) else: raise AssertionError( f"Unknown policy loss type: {self._loss_type}." 
) # Apply optimistic weighting if configured. if self.opt_value_head is not None: value_pred = predictions.value[self.opt_value_head] target_q = sample.values[self.opt_value_type, 0] loss = loss * self._compute_optimistic_weight(value_pred, target_q) return loss class MovesLeftLoss(LossBase): def __init__(self, config: MovesLeftLossConfig) -> None: super().__init__(config) self.value_type = config.value_type def __call__( self, predictions: ModelPrediction, sample: TrainingSample, ) -> jax.Array: movesleft_pred = predictions.movesleft[self.head_name] # Extract movesleft from sample. # sample.values shape: [6, 3], component 2 is movesleft. movesleft_targets = sample.values[self.value_type, 2] # Scale the loss to similar range as other losses. scale = 20.0 targets = movesleft_targets / scale scaled_predictions = movesleft_pred / scale # Huber loss huber_loss = optax.huber_loss( predictions=scaled_predictions, targets=targets, delta=10.0 / scale ) assert isinstance(huber_loss, jax.Array) return huber_loss.squeeze() class ValueErrorLoss(LossBase): def __init__(self, config: ValueErrorLossConfig) -> None: super().__init__(config) self.value_type = config.value_type self.propagate_value_gradients = config.propagate_value_gradients def __call__( self, predictions: ModelPrediction, sample: TrainingSample, ) -> jax.Array: value_pred = predictions.value[self.head_name] wdl_logits = value_pred[0] error_pred = value_pred[1] assert error_pred is not None # Convert WDL to Q value. predicted_q = _compute_q_from_wdl(wdl_logits) # Get target Q value. target_q = sample.values[self.value_type, 0] # Compute actual squared error. actual_squared_error = jnp.square(predicted_q - target_q) # Optionally block gradients to WDL head. if not self.propagate_value_gradients: actual_squared_error = jax.lax.stop_gradient(actual_squared_error) # MSE between error prediction and actual error. 
        loss = jnp.square(error_pred - actual_squared_error)
        return loss.squeeze()


class ValueCategoricalLoss(LossBase):
    """Cross-entropy loss over a bucketed (categorical) value distribution."""

    def __init__(self, config: ValueCategoricalLossConfig) -> None:
        super().__init__(config)
        self.value_type = config.value_type

    def __call__(
        self,
        predictions: ModelPrediction,
        sample: TrainingSample,
    ) -> jax.Array:
        value_pred = predictions.value[self.head_name]
        categorical_logits = value_pred[2]
        assert categorical_logits is not None
        # Get target Q value from sample.
        target_q = sample.values[self.value_type, 0]
        # Convert Q to bucket index: map [-1, 1) to [0, num_buckets).
        num_buckets = categorical_logits.shape[-1]
        bucket_index = jnp.floor((target_q + 1.0) / 2.0 * num_buckets).astype(
            jnp.int32
        )
        # Clip so q == 1.0 lands in the last bucket instead of overflowing.
        bucket_index = jnp.clip(bucket_index, 0, num_buckets - 1)
        # Create one-hot target.
        target_one_hot = jax.nn.one_hot(bucket_index, num_buckets)
        # Compute softmax cross-entropy.
        loss = optax.softmax_cross_entropy(
            logits=categorical_logits,
            labels=jax.lax.stop_gradient(target_one_hot),
        )
        assert isinstance(loss, jax.Array)
        return loss


================================================ FILE: src/lczero_training/model/model.py ================================================
import dataclasses
import math
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from flax import nnx

from proto import model_config_pb2

from .embedding import Embedding
from .encoder import EncoderTower
from .movesleft_head import MovesLeftHead
from .policy_head import PolicyHead
from .utils import get_dtype
from .value_head import ValueHead


@jax.tree_util.register_dataclass
@dataclasses.dataclass
class ModelPrediction:
    """Output predictions from LczeroModel.

    Fields:
        value: Dictionary mapping head names to value prediction tuples.
        policy: Dictionary mapping head names to policy logits.
        movesleft: Dictionary mapping head names to moves-left predictions.
    """

    # Value tuple layout: (wdl_logits, optional error, optional categorical).
    value: dict[str, Tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]]
    policy: dict[str, jax.Array]
    movesleft: dict[str, jax.Array]


class LczeroModel(nnx.Module):
    """Full network: embedding, encoder tower, and per-name output heads."""

    def __init__(self, config: model_config_pb2.ModelConfig, *, rngs: nnx.Rngs):
        self.config = config
        # Number of input planes in the training data format.
        self._input_channels = 112
        # DeepNorm scaling factors derived from the encoder depth.
        deepnorm_beta = math.pow(8.0 * config.encoder.num_blocks, -0.25)
        self.embedding = Embedding(
            input_channels=self._input_channels,
            config=config.embedding,
            defaults=config.defaults,
            deepnorm_alpha=math.pow(2.0 * config.encoder.num_blocks, -0.25),
            deepnorm_beta=deepnorm_beta,
            rngs=rngs,
        )
        assert self.config.encoder.num_blocks > 0
        self.encoders = EncoderTower(
            in_features=config.embedding.embedding_size,
            config=config.encoder,
            defaults=config.defaults,
            deepnorm_beta=deepnorm_beta,
            rngs=rngs,
        )
        # One value head per configured name.
        self.value_heads = nnx.Dict(
            {
                head_config.name: ValueHead(
                    in_features=config.embedding.embedding_size,
                    config=head_config,
                    defaults=config.defaults,
                    rngs=rngs,
                )
                for head_config in config.value_head
            }
        )
        # Named to appear before 'policy_heads' alphabetically in pytree state.
        # This ensures shared embedding appears at parent level during
        # serialization.
        # Optional embedding shared by all policy heads (see naming note
        # above this attribute in the original source).
        self.policy_embedding_shared: Optional[nnx.Linear]
        if config.HasField("shared_policy_embedding_size"):
            self.policy_embedding_shared = nnx.Linear(
                in_features=config.embedding.embedding_size,
                out_features=config.shared_policy_embedding_size,
                rngs=rngs,
            )
        else:
            self.policy_embedding_shared = None
        self.policy_heads = nnx.Dict(
            {
                head_config.name: PolicyHead(
                    in_features=config.embedding.embedding_size,
                    config=head_config,
                    defaults=config.defaults,
                    shared_embedding=self.policy_embedding_shared,
                    rngs=rngs,
                )
                for head_config in config.policy_head
            }
        )
        self.movesleft_heads = nnx.Dict(
            {
                head_config.name: MovesLeftHead(
                    in_features=config.embedding.embedding_size,
                    config=head_config,
                    defaults=config.defaults,
                    rngs=rngs,
                )
                for head_config in config.movesleft_head
            }
        )

    def __call__(self, x: jax.Array) -> ModelPrediction:
        """Run the network on one position's input planes.

        Reshapes the planes into 64 square tokens, runs embedding and
        encoder tower, then evaluates every configured head.
        """
        x = jnp.astype(x, get_dtype(self.config.defaults.compute_dtype))
        # (C, 8, 8) -> (8, 8, C) -> (64, C): one token per board square.
        x = jnp.transpose(x, (1, 2, 0))
        x = jnp.reshape(x, (64, self._input_channels))
        x = self.embedding(x)
        x = self.encoders(x)
        value = {name: head(x) for name, head in self.value_heads.items()}
        policy = {name: head(x) for name, head in self.policy_heads.items()}
        movesleft = {
            name: head(x) for name, head in self.movesleft_heads.items()
        }
        return ModelPrediction(value=value, policy=policy, movesleft=movesleft)


================================================ FILE: src/lczero_training/model/movesleft_head.py ================================================
import jax
from flax import nnx

from proto import model_config_pb2

from .utils import get_activation


class MovesLeftHead(nnx.Module):
    """Moves-left head: per-square embedding, dense layer, scalar output."""

    def __init__(
        self,
        in_features: int,
        config: model_config_pb2.MovesLeftHeadConfig,
        defaults: model_config_pb2.DefaultsConfig,
        *,
        rngs: nnx.Rngs,
    ):
        self.activation = defaults.activation
        self.embed = nnx.Linear(
            in_features=in_features,
            out_features=config.num_channels,
            rngs=rngs,
        )
        # Dense layer consumes all 64 squares' channels flattened together.
        self.dense1 = nnx.Linear(
            in_features=config.num_channels * 64,
            out_features=128,
            rngs=rngs,
        )
        self.out = nnx.Linear(
            in_features=128,
            out_features=1,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        """Predict a non-negative moves-left scalar from encoder tokens."""
        x = self.embed(x).flatten()
        x = get_activation(self.activation)(x)
        x = self.dense1(x)
        x = get_activation(self.activation)(x)
        x = self.out(x)
        # relu clamps the prediction to be non-negative.
        return nnx.relu(x)


================================================ FILE: src/lczero_training/model/policy_head.py ================================================
import math
from typing import Optional

import jax
import jax.numpy as jnp
from flax import nnx

from proto import model_config_pb2

from .utils import get_activation  # , get_policy_map


class PolicyHead(nnx.Module):
    """Attention-style policy head producing move logits."""

    def __init__(
        self,
        in_features: int,
        config: model_config_pb2.PolicyHeadConfig,
        defaults: model_config_pb2.DefaultsConfig,
        shared_embedding: Optional[nnx.Linear] = None,
        *,
        rngs: nnx.Rngs,
    ):
        # Exactly one of shared_embedding / config.embedding_size must be set.
        assert (shared_embedding is not None) != config.HasField(
            "embedding_size"
        )
        self.activation = defaults.activation
        if shared_embedding is not None:
            self.tokens = shared_embedding
            embedding_size = shared_embedding.out_features
        else:
            self.tokens = nnx.Linear(
                in_features=in_features,
                out_features=config.embedding_size,
                rngs=rngs,
            )
            embedding_size = config.embedding_size
        self.q = nnx.Linear(
            in_features=embedding_size,
            out_features=config.d_model,
            rngs=rngs,
        )
        self.k = nnx.Linear(
            in_features=embedding_size,
            out_features=config.d_model,
            rngs=rngs,
        )
        # Attention scale factor sqrt(d_model).
        self.dk = math.sqrt(config.d_model)
        self.promotion_dense = nnx.Linear(
            in_features=config.d_model,
            out_features=4,
            use_bias=False,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        """Compute policy logits via query/key attention plus promotion
        offsets, remapped through the fixed policy-index table."""
        x = self.tokens(x)
        x = get_activation(self.activation)(x)
        q = self.q(x)
        k = self.k(x)
        qk = jnp.einsum("qd,kd->qk", q, k)
        # Last 8 squares (8th rank) are promotion destinations.
        promotion_keys = k[-8:, :]
        promotion_offsets = self.promotion_dense(promotion_keys)
        promotion_offsets = promotion_offsets.transpose((1, 0)) * self.dk
        # knight offset is added to the other three
        promotion_offsets = promotion_offsets[:3, :] + promotion_offsets[3:4, :]
        # Pawn-on-7th-rank -> 8th-rank move logits.
        n_promo_logits = qk[-16:-8, -8:]
        q_promo_logits = jnp.expand_dims(
n_promo_logits + promotion_offsets[0:1, :], axis=-1 ) r_promo_logits = jnp.expand_dims( n_promo_logits + promotion_offsets[1:2, :], axis=-1 ) b_promo_logits = jnp.expand_dims( n_promo_logits + promotion_offsets[2:3, :], axis=-1 ) promotion_logits = jnp.concatenate( [q_promo_logits, r_promo_logits, b_promo_logits], axis=-1 ) policy_attn_logits = qk / self.dk promotion_logits = promotion_logits.reshape((8, 24)) / self.dk logits = jnp.concatenate( [policy_attn_logits.flatten(), promotion_logits.flatten()], axis=-1 ) return logits[_policy_map] # fmt: off _policy_map = jnp.array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 17, 18, 24, 27, 32, 36, 40, 45, 48, 54, 56, 63, 64, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 80, 81, 82, 83, 89, 92, 97, 101, 105, 110, 113, 119, 121, 128, 129, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 144, 145, 146, 147, 148, 154, 157, 162, 166, 170, 175, 178, 186, 192, 193, 194, 196, 197, 198, 199, 201, 202, 203, 204, 205, 209, 210, 211, 212, 213, 216, 219, 222, 227, 231, 235, 243, 251, 256, 257, 258, 259, 261, 262, 263, 266, 267, 268, 269, 270, 274, 275, 276, 277, 278, 281, 284, 287, 288, 292, 300, 308, 316, 320, 321, 322, 323, 324, 326, 327, 331, 332, 333, 334, 335, 339, 340, 341, 342, 343, 346, 349, 353, 357, 360, 365, 373, 381, 384, 385, 386, 387, 388, 389, 391, 396, 397, 398, 399, 404, 405, 406, 407, 411, 414, 418, 422, 425, 430, 432, 438, 446, 448, 449, 450, 451, 452, 453, 454, 461, 462, 463, 469, 470, 471, 476, 479, 483, 487, 490, 495, 497, 503, 504, 511, 512, 513, 514, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 536, 537, 538, 544, 547, 552, 556, 560, 565, 568, 574, 576, 577, 578, 579, 584, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 600, 601, 602, 603, 609, 612, 617, 621, 625, 630, 633, 639, 640, 641, 642, 643, 644, 648, 649, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 664, 665, 666, 667, 668, 674, 677, 682, 686, 690, 695, 698, 705, 706, 707, 708, 709, 712, 713, 714, 716, 717, 718, 719, 721, 722, 723, 724, 725, 729, 730, 
731, 732, 733, 736, 739, 742, 747, 751, 755, 763, 770, 771, 772, 773, 774, 776, 777, 778, 779, 781, 782, 783, 786, 787, 788, 789, 790, 794, 795, 796, 797, 798, 801, 804, 807, 808, 812, 820, 828, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 846, 847, 851, 852, 853, 854, 855, 859, 860, 861, 862, 863, 866, 869, 873, 877, 880, 885, 893, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 911, 916, 917, 918, 919, 924, 925, 926, 927, 931, 934, 938, 942, 945, 950, 952, 958, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 981, 982, 983, 989, 990, 991, 996, 999, 1003, 1007, 1010, 1015, 1017, 1023, 1024, 1025, 1026, 1032, 1033, 1034, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1056, 1057, 1058, 1064, 1067, 1072, 1076, 1080, 1085, 1088, 1089, 1090, 1091, 1096, 1097, 1098, 1099, 1104, 1106, 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1120, 1121, 1122, 1123, 1129, 1132, 1137, 1141, 1145, 1150, 1152, 1153, 1154, 1155, 1156, 1160, 1161, 1162, 1163, 1164, 1168, 1169, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1180, 1184, 1185, 1186, 1187, 1188, 1194, 1197, 1202, 1206, 1210, 1215, 1217, 1218, 1219, 1220, 1221, 1225, 1226, 1227, 1228, 1229, 1232, 1233, 1234, 1236, 1237, 1238, 1239, 1241, 1242, 1243, 1244, 1245, 1249, 1250, 1251, 1252, 1253, 1256, 1259, 1262, 1267, 1271, 1275, 1282, 1283, 1284, 1285, 1286, 1290, 1291, 1292, 1293, 1294, 1296, 1297, 1298, 1299, 1301, 1302, 1303, 1306, 1307, 1308, 1309, 1310, 1314, 1315, 1316, 1317, 1318, 1321, 1324, 1327, 1328, 1332, 1340, 1347, 1348, 1349, 1350, 1351, 1355, 1356, 1357, 1358, 1359, 1360, 1361, 1362, 1363, 1364, 1366, 1367, 1371, 1372, 1373, 1374, 1375, 1379, 1380, 1381, 1382, 1383, 1386, 1389, 1393, 1397, 1400, 1405, 1412, 1413, 1414, 1415, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1431, 1436, 1437, 1438, 1439, 1444, 1445, 1446, 1447, 1451, 1454, 1458, 1462, 1465, 1470, 1477, 1478, 1479, 1485, 1486, 1487, 1488, 1489, 1490, 1491, 1492, 1493, 1494, 1501, 1502, 1503, 1509, 1510, 
1511, 1516, 1519, 1523, 1527, 1530, 1535, 1536, 1539, 1544, 1545, 1546, 1552, 1553, 1554, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1570, 1576, 1577, 1578, 1584, 1587, 1592, 1596, 1601, 1604, 1608, 1609, 1610, 1611, 1616, 1617, 1618, 1619, 1624, 1626, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1640, 1641, 1642, 1643, 1649, 1652, 1657, 1661, 1666, 1669, 1672, 1673, 1674, 1675, 1676, 1680, 1681, 1682, 1683, 1684, 1688, 1689, 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1704, 1705, 1706, 1707, 1708, 1714, 1717, 1722, 1726, 1728, 1731, 1734, 1737, 1738, 1739, 1740, 1741, 1745, 1746, 1747, 1748, 1749, 1752, 1753, 1754, 1756, 1757, 1758, 1759, 1761, 1762, 1763, 1764, 1765, 1769, 1770, 1771, 1772, 1773, 1776, 1779, 1782, 1787, 1791, 1793, 1796, 1799, 1802, 1803, 1804, 1805, 1806, 1810, 1811, 1812, 1813, 1814, 1816, 1817, 1818, 1819, 1821, 1822, 1823, 1826, 1827, 1828, 1829, 1830, 1834, 1835, 1836, 1837, 1838, 1841, 1844, 1847, 1848, 1852, 1858, 1861, 1867, 1868, 1869, 1870, 1871, 1875, 1876, 1877, 1878, 1879, 1880, 1881, 1882, 1883, 1884, 1886, 1887, 1891, 1892, 1893, 1894, 1895, 1899, 1900, 1901, 1902, 1903, 1906, 1909, 1913, 1917, 1923, 1926, 1932, 1933, 1934, 1935, 1940, 1941, 1942, 1943, 1944, 1945, 1946, 1947, 1948, 1949, 1951, 1956, 1957, 1958, 1959, 1964, 1965, 1966, 1967, 1971, 1974, 1978, 1982, 1988, 1991, 1997, 1998, 1999, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2021, 2022, 2023, 2029, 2030, 2031, 2036, 2039, 2043, 2047, 2048, 2052, 2056, 2059, 2064, 2065, 2066, 2072, 2073, 2074, 2081, 2082, 2083, 2084, 2085, 2086, 2087, 2088, 2089, 2090, 2096, 2097, 2098, 2104, 2107, 2113, 2117, 2121, 2124, 2128, 2129, 2130, 2131, 2136, 2137, 2138, 2139, 2144, 2146, 2147, 2148, 2149, 2150, 2151, 2152, 2153, 2154, 2155, 2160, 2161, 2162, 2163, 2169, 2172, 2178, 2182, 2186, 2189, 2192, 2193, 2194, 2195, 2196, 2200, 2201, 2202, 2203, 2204, 2208, 2209, 2211, 2212, 2213, 2214, 2215, 2216, 2217, 2218, 2219, 2220, 2224, 2225, 
2226, 2227, 2228, 2234, 2237, 2243, 2247, 2248, 2251, 2254, 2257, 2258, 2259, 2260, 2261, 2265, 2266, 2267, 2268, 2269, 2272, 2273, 2274, 2276, 2277, 2278, 2279, 2281, 2282, 2283, 2284, 2285, 2289, 2290, 2291, 2292, 2293, 2296, 2299, 2302, 2304, 2308, 2313, 2316, 2319, 2322, 2323, 2324, 2325, 2326, 2330, 2331, 2332, 2333, 2334, 2336, 2337, 2338, 2339, 2341, 2342, 2343, 2346, 2347, 2348, 2349, 2350, 2354, 2355, 2356, 2357, 2358, 2361, 2364, 2367, 2369, 2373, 2378, 2381, 2387, 2388, 2389, 2390, 2391, 2395, 2396, 2397, 2398, 2399, 2400, 2401, 2402, 2403, 2404, 2406, 2407, 2411, 2412, 2413, 2414, 2415, 2419, 2420, 2421, 2422, 2423, 2426, 2429, 2434, 2438, 2443, 2446, 2452, 2453, 2454, 2455, 2460, 2461, 2462, 2463, 2464, 2465, 2466, 2467, 2468, 2469, 2471, 2476, 2477, 2478, 2479, 2484, 2485, 2486, 2487, 2491, 2494, 2499, 2503, 2508, 2511, 2517, 2518, 2519, 2525, 2526, 2527, 2528, 2529, 2530, 2531, 2532, 2533, 2534, 2541, 2542, 2543, 2549, 2550, 2551, 2556, 2559, 2560, 2565, 2568, 2572, 2576, 2579, 2584, 2585, 2586, 2592, 2593, 2594, 2601, 2602, 2603, 2604, 2605, 2606, 2607, 2608, 2609, 2610, 2616, 2617, 2618, 2625, 2630, 2633, 2637, 2641, 2644, 2648, 2649, 2650, 2651, 2656, 2657, 2658, 2659, 2664, 2666, 2667, 2668, 2669, 2670, 2671, 2672, 2673, 2674, 2675, 2680, 2681, 2682, 2683, 2690, 2695, 2698, 2702, 2706, 2709, 2712, 2713, 2714, 2715, 2716, 2720, 2721, 2722, 2723, 2724, 2728, 2729, 2731, 2732, 2733, 2734, 2735, 2736, 2737, 2738, 2739, 2740, 2744, 2745, 2746, 2747, 2748, 2755, 2763, 2767, 2768, 2771, 2774, 2777, 2778, 2779, 2780, 2781, 2785, 2786, 2787, 2788, 2789, 2792, 2793, 2794, 2796, 2797, 2798, 2799, 2801, 2802, 2803, 2804, 2805, 2809, 2810, 2811, 2812, 2813, 2820, 2824, 2828, 2833, 2836, 2839, 2842, 2843, 2844, 2845, 2846, 2850, 2851, 2852, 2853, 2854, 2856, 2857, 2858, 2859, 2861, 2862, 2863, 2866, 2867, 2868, 2869, 2870, 2874, 2875, 2876, 2877, 2878, 2880, 2885, 2889, 2893, 2898, 2901, 2907, 2908, 2909, 2910, 2911, 2915, 2916, 2917, 2918, 2919, 2920, 2921, 
2922, 2923, 2924, 2926, 2927, 2931, 2932, 2933, 2934, 2935, 2939, 2940, 2941, 2942, 2943, 2945, 2950, 2954, 2958, 2963, 2966, 2972, 2973, 2974, 2975, 2980, 2981, 2982, 2983, 2984, 2985, 2986, 2987, 2988, 2989, 2991, 2996, 2997, 2998, 2999, 3004, 3005, 3006, 3007, 3010, 3015, 3019, 3023, 3028, 3031, 3037, 3038, 3039, 3045, 3046, 3047, 3048, 3049, 3050, 3051, 3052, 3053, 3054, 3061, 3062, 3063, 3069, 3070, 3071, 3072, 3078, 3080, 3085, 3088, 3092, 3096, 3099, 3104, 3105, 3106, 3112, 3113, 3114, 3121, 3122, 3123, 3124, 3125, 3126, 3127, 3128, 3129, 3130, 3137, 3143, 3145, 3150, 3153, 3157, 3161, 3164, 3168, 3169, 3170, 3171, 3176, 3177, 3178, 3179, 3184, 3186, 3187, 3188, 3189, 3190, 3191, 3192, 3193, 3194, 3195, 3202, 3210, 3215, 3218, 3222, 3226, 3229, 3232, 3233, 3234, 3235, 3236, 3240, 3241, 3242, 3243, 3244, 3248, 3249, 3251, 3252, 3253, 3254, 3255, 3256, 3257, 3258, 3259, 3260, 3267, 3275, 3283, 3287, 3288, 3291, 3294, 3297, 3298, 3299, 3300, 3301, 3305, 3306, 3307, 3308, 3309, 3312, 3313, 3314, 3316, 3317, 3318, 3319, 3321, 3322, 3323, 3324, 3325, 3332, 3340, 3344, 3348, 3353, 3356, 3359, 3362, 3363, 3364, 3365, 3366, 3370, 3371, 3372, 3373, 3374, 3376, 3377, 3378, 3379, 3381, 3382, 3383, 3386, 3387, 3388, 3389, 3390, 3397, 3400, 3405, 3409, 3413, 3418, 3421, 3427, 3428, 3429, 3430, 3431, 3435, 3436, 3437, 3438, 3439, 3440, 3441, 3442, 3443, 3444, 3446, 3447, 3451, 3452, 3453, 3454, 3455, 3456, 3462, 3465, 3470, 3474, 3478, 3483, 3486, 3492, 3493, 3494, 3495, 3500, 3501, 3502, 3503, 3504, 3505, 3506, 3507, 3508, 3509, 3511, 3516, 3517, 3518, 3519, 3521, 3527, 3530, 3535, 3539, 3543, 3548, 3551, 3557, 3558, 3559, 3565, 3566, 3567, 3568, 3569, 3570, 3571, 3572, 3573, 3574, 3581, 3582, 3583, 3584, 3591, 3592, 3598, 3600, 3605, 3608, 3612, 3616, 3619, 3624, 3625, 3626, 3632, 3633, 3634, 3641, 3642, 3643, 3644, 3645, 3646, 3647, 3649, 3657, 3663, 3665, 3670, 3673, 3677, 3681, 3684, 3688, 3689, 3690, 3691, 3696, 3697, 3698, 3699, 3704, 3706, 3707, 3708, 3709, 3710, 
3711, 3714, 3722, 3730, 3735, 3738, 3742, 3746, 3749, 3752, 3753, 3754, 3755, 3756, 3760, 3761, 3762, 3763, 3764, 3768, 3769, 3771, 3772, 3773, 3774, 3775, 3779, 3787, 3795, 3803, 3807, 3808, 3811, 3814, 3817, 3818, 3819, 3820, 3821, 3825, 3826, 3827, 3828, 3829, 3832, 3833, 3834, 3836, 3837, 3838, 3839, 3844, 3852, 3860, 3864, 3868, 3873, 3876, 3879, 3882, 3883, 3884, 3885, 3886, 3890, 3891, 3892, 3893, 3894, 3896, 3897, 3898, 3899, 3901, 3902, 3903, 3909, 3917, 3920, 3925, 3929, 3933, 3938, 3941, 3947, 3948, 3949, 3950, 3951, 3955, 3956, 3957, 3958, 3959, 3960, 3961, 3962, 3963, 3964, 3966, 3967, 3974, 3976, 3982, 3985, 3990, 3994, 3998, 4003, 4006, 4012, 4013, 4014, 4015, 4020, 4021, 4022, 4023, 4024, 4025, 4026, 4027, 4028, 4029, 4031, 4032, 4039, 4041, 4047, 4050, 4055, 4059, 4063, 4068, 4071, 4077, 4078, 4079, 4085, 4086, 4087, 4088, 4089, 4090, 4091, 4092, 4093, 4094, 4096, 4097, 4098, 4099, 4100, 4101, 4120, 4121, 4122, 4123, 4124, 4125, 4126, 4127, 4128, 4147, 4148, 4149, 4150, 4151, 4152, 4153, 4154, 4155, 4174, 4175, 4176, 4177, 4178, 4179, 4180, 4181, 4182, 4201, 4202, 4203, 4204, 4205, 4206, 4207, 4208, 4209, 4228, 4229, 4230, 4231, 4232, 4233, 4234, 4235, 4236, 4255, 4256, 4257, 4258, 4259, 4260, 4261, 4262, 4263, 4282, 4283, 4284, 4285, 4286, 4287, ], jnp.int32) ================================================ FILE: src/lczero_training/model/shared.py ================================================ import jax from flax import nnx from flax.linen import initializers as flax_initializers from proto import net_pb2 from .utils import get_activation class Ffn(nnx.Module): def __init__( self, in_features: int, hidden_features: int, hidden_activation: net_pb2.NetworkFormat.ActivationFunction, deepnorm_beta: float, *, rngs: nnx.Rngs, ): deepnorm_init = flax_initializers.variance_scaling( scale=deepnorm_beta, mode="fan_avg", distribution="truncated_normal", ) out_features = in_features self.linear1 = nnx.Linear( in_features=in_features, 
            out_features=hidden_features,
            kernel_init=deepnorm_init,
            rngs=rngs,
        )
        self.activation = hidden_activation
        self.linear2 = nnx.Linear(
            in_features=hidden_features,
            out_features=out_features,
            kernel_init=deepnorm_init,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply linear -> activation -> linear."""
        x = self.linear1(x)
        x = get_activation(self.activation)(x)
        x = self.linear2(x)
        return x


================================================ FILE: src/lczero_training/model/utils.py ================================================
from typing import Any

import jax.numpy as jnp
from flax import nnx
from jax.nn import mish

from proto import net_pb2
from proto.hlo_pb2 import XlaShapeProto


def get_activation(
    activation: net_pb2.NetworkFormat.ActivationFunction,
) -> Any:
    """Map a proto activation enum to a callable; raises KeyError for
    unsupported enum values."""
    return {
        net_pb2.NetworkFormat.ACTIVATION_MISH: mish,
        net_pb2.NetworkFormat.ACTIVATION_RELU: nnx.relu,
        net_pb2.NetworkFormat.ACTIVATION_NONE: lambda x: x,
        net_pb2.NetworkFormat.ACTIVATION_TANH: nnx.tanh,
        net_pb2.NetworkFormat.ACTIVATION_SIGMOID: nnx.sigmoid,
        net_pb2.NetworkFormat.ACTIVATION_SELU: nnx.selu,
        net_pb2.NetworkFormat.ACTIVATION_SWISH: nnx.swish,
        net_pb2.NetworkFormat.ACTIVATION_SOFTMAX: nnx.softmax,
    }[activation]


def get_dtype(dtype: XlaShapeProto.Type) -> jnp.dtype:
    """Map an XLA shape proto element type to the jax/numpy dtype;
    raises KeyError for unsupported enum values."""
    return {
        XlaShapeProto.PRED: jnp.bool_,
        XlaShapeProto.S4: jnp.int4,
        XlaShapeProto.S8: jnp.int8,
        XlaShapeProto.S16: jnp.int16,
        XlaShapeProto.S32: jnp.int32,
        XlaShapeProto.S64: jnp.int64,
        XlaShapeProto.U4: jnp.uint4,
        XlaShapeProto.U8: jnp.uint8,
        XlaShapeProto.U16: jnp.uint16,
        XlaShapeProto.U32: jnp.uint32,
        XlaShapeProto.U64: jnp.uint64,
        XlaShapeProto.F16: jnp.float16,
        XlaShapeProto.F32: jnp.float32,
        XlaShapeProto.BF16: jnp.bfloat16,
        XlaShapeProto.F64: jnp.float64,
        XlaShapeProto.F8E5M2: jnp.float8_e5m2,
        XlaShapeProto.F8E4M3FN: jnp.float8_e4m3fn,
        XlaShapeProto.F8E4M3B11FNUZ: jnp.float8_e4m3b11fnuz,
        XlaShapeProto.F8E5M2FNUZ: jnp.float8_e5m2fnuz,
        XlaShapeProto.F8E4M3FNUZ: jnp.float8_e4m3fnuz,
        XlaShapeProto.C64: jnp.complex64,
        XlaShapeProto.C128: jnp.complex128,
}[dtype] ================================================ FILE: src/lczero_training/model/value_head.py ================================================ from typing import Optional, Tuple import jax from flax import nnx from proto import model_config_pb2 from .utils import get_activation class ValueHead(nnx.Module): def __init__( self, in_features: int, config: model_config_pb2.ValueHeadConfig, defaults: model_config_pb2.DefaultsConfig, rngs: nnx.Rngs, ): self.activation = defaults.activation self.has_error_output = config.has_error_output self.num_categorical_buckets = config.num_categorical_buckets self.embed = nnx.Linear( in_features=in_features, out_features=config.num_channels, rngs=rngs, ) self.dense1 = nnx.Linear( in_features=config.num_channels * 64, out_features=128, rngs=rngs, ) self.wdl = nnx.Linear( in_features=128, out_features=3, rngs=rngs, ) if self.has_error_output: self.error = nnx.Linear( in_features=128, out_features=1, rngs=rngs, ) if self.num_categorical_buckets > 0: self.categorical = nnx.Linear( in_features=128, out_features=self.num_categorical_buckets, rngs=rngs, ) def __call__( self, x: jax.Array ) -> Tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: x = self.embed(x).flatten() x = get_activation(self.activation)(x) x = self.dense1(x) x = get_activation(self.activation)(x) wdl = self.wdl(x) error = nnx.sigmoid(self.error(x)) if self.has_error_output else None categorical = ( self.categorical(x) if self.num_categorical_buckets > 0 else None ) return (wdl, error, categorical) def predict(self, x: jax.Array) -> jax.Array: return nnx.softmax(self(x)[0]) ================================================ FILE: src/lczero_training/py.typed ================================================ ================================================ FILE: src/lczero_training/tests/test_protobuf.py ================================================ """Test protobuf compilation and functionality.""" def test_protobuf_import() -> None: """Test that 
protobuf files can be imported.""" import proto.data_loader_config_pb2 as data_loader_config_pb2 import proto.model_config_pb2 as model_config_pb2 import proto.root_config_pb2 as root_config_pb2 import proto.training_config_pb2 as training_config_pb2 import proto.training_metrics_pb2 as training_metrics_pb2 # Test creating config objects data_loader_config = data_loader_config_pb2.DataLoaderConfig() assert data_loader_config is not None root_config = root_config_pb2.RootConfig() assert root_config is not None training_config = training_config_pb2.TrainingConfig() assert training_config is not None model_config = model_config_pb2.ModelConfig() assert model_config is not None metrics = training_metrics_pb2.DataLoaderMetricsProto() assert metrics is not None def test_protobuf_functionality() -> None: """Test basic protobuf functionality.""" import proto.data_loader_config_pb2 as data_loader_config_pb2 import proto.root_config_pb2 as root_config_pb2 # Create a config and set some values config = data_loader_config_pb2.DataLoaderConfig() stage = config.stage.add() stage.name = "file_path_provider" stage.file_path_provider.directory = "/test/path" # Serialize and deserialize serialized = config.SerializeToString() assert len(serialized) > 0 config2 = data_loader_config_pb2.DataLoaderConfig() config2.ParseFromString(serialized) assert config2.stage[0].file_path_provider.directory == "/test/path" # Test RootConfig functionality root_config = root_config_pb2.RootConfig() root_config.name = "test_config" stage_config = root_config.data_loader.stage.add() stage_config.name = "file_path_provider" stage_config.file_path_provider.directory = "/test/path" # Serialize and deserialize root config root_serialized = root_config.SerializeToString() assert len(root_serialized) > 0 root_config2 = root_config_pb2.RootConfig() root_config2.ParseFromString(root_serialized) assert root_config2.name == "test_config" assert ( root_config2.data_loader.stage[0].file_path_provider.directory == 
"/test/path" ) ================================================ FILE: src/lczero_training/tests/test_protocol_registry.py ================================================ """Test script for the protocol registry system.""" from dataclasses import dataclass from typing import Any import pytest from lczero_training.daemon.protocol.registry import ( CLASS_TO_TYPE_MAP, TYPE_TO_CLASS_MAP, register, ) @pytest.fixture(autouse=True) def clear_registry() -> Any: """Clear registry maps before each test.""" TYPE_TO_CLASS_MAP.clear() CLASS_TO_TYPE_MAP.clear() yield TYPE_TO_CLASS_MAP.clear() CLASS_TO_TYPE_MAP.clear() def test_basic_registration() -> None: """Test basic event type registration.""" @register("test_event") @dataclass class BasicPayload: content: str # Check forward mapping assert TYPE_TO_CLASS_MAP["test_event"] == BasicPayload # Check reverse mapping assert CLASS_TO_TYPE_MAP[BasicPayload] == "test_event" def test_duplicate_event_type() -> None: """Test that duplicate event types are rejected.""" @register("duplicate_event") @dataclass class FirstPayload: data: str with pytest.raises( ValueError, match=r".*duplicate_event.*already registered.*" ): @register("duplicate_event") # Should fail @dataclass class SecondPayload: other_data: int def test_duplicate_class() -> None: """Test that duplicate classes are rejected.""" @dataclass class PayloadClass: data: str # Register once register("first_event")(PayloadClass) # Try to register same class again - should fail with pytest.raises( ValueError, match=r".*PayloadClass.*already registered.*" ): register("second_event")(PayloadClass) def test_non_class_registration() -> None: """Test that non-classes are rejected.""" with pytest.raises(TypeError, match=r".*can only be used on classes.*"): # Try to register a string instead of a class @register("invalid_event") # type: ignore[arg-type] def not_a_class() -> None: pass def test_multiple_registrations() -> None: """Test multiple valid registrations work correctly.""" 
@register("event_one") @dataclass class PayloadOne: data: str @register("event_two") @dataclass class PayloadTwo: value: int @register("event_three") @dataclass class PayloadThree: items: list # Check all mappings exist assert TYPE_TO_CLASS_MAP["event_one"] == PayloadOne assert TYPE_TO_CLASS_MAP["event_two"] == PayloadTwo assert TYPE_TO_CLASS_MAP["event_three"] == PayloadThree assert CLASS_TO_TYPE_MAP[PayloadOne] == "event_one" assert CLASS_TO_TYPE_MAP[PayloadTwo] == "event_two" assert CLASS_TO_TYPE_MAP[PayloadThree] == "event_three" # Check we have exactly 3 entries in each map assert len(TYPE_TO_CLASS_MAP) == 3 assert len(CLASS_TO_TYPE_MAP) == 3 def test_registry_persistence() -> None: """Test that registry persists across imports.""" @register("persistent_event") @dataclass class PersistentPayload: data: str # Re-import the module from lczero_training.daemon.protocol.registry import ( CLASS_TO_TYPE_MAP as imported_class_map, ) from lczero_training.daemon.protocol.registry import ( TYPE_TO_CLASS_MAP as imported_type_map, ) # Check the registration persists assert imported_type_map["persistent_event"] == PersistentPayload assert imported_class_map[PersistentPayload] == "persistent_event" ================================================ FILE: src/lczero_training/tests/test_weights_tool.py ================================================ """Test weights tool arithmetic operations.""" import os import tempfile import numpy as np import proto.net_pb2 as net_pb2 from lczero_training.tools.weight_wrappers import NetWrapper from lczero_training.tools.weights_tool import load_weights, save_weights def test_weights_arithmetic() -> None: """Test arithmetic operations on simple networks.""" # Create network A with a single layer containing value 10.0. net_a = net_pb2.Net() net_a.format.weights_encoding = net_pb2.Format.LINEAR16 net_a.weights.ip1_val_w.min_val = 10.0 net_a.weights.ip1_val_w.max_val = 10.0 # LINEAR16 encoding: value 10.0 maps to uint16 value 32767 (mid-point). 
def test_weights_arithmetic() -> None:
    """Test arithmetic operations on simple networks."""
    # Create network A with a single layer containing value 10.0.
    net_a = net_pb2.Net()
    net_a.format.weights_encoding = net_pb2.Format.LINEAR16
    net_a.weights.ip1_val_w.min_val = 10.0
    net_a.weights.ip1_val_w.max_val = 10.0
    # LINEAR16 encoding: value 10.0 maps to uint16 value 32767 (mid-point).
    net_a.weights.ip1_val_w.params = np.array(
        [32767], dtype=np.uint16
    ).tobytes()

    # Create network B with a single layer containing value 20.0.
    net_b = net_pb2.Net()
    net_b.format.weights_encoding = net_pb2.Format.LINEAR16
    net_b.weights.ip1_val_w.min_val = 20.0
    net_b.weights.ip1_val_w.max_val = 20.0
    net_b.weights.ip1_val_w.params = np.array(
        [32767], dtype=np.uint16
    ).tobytes()

    # Wrap the networks.
    wrapper_a = NetWrapper(net_a)
    wrapper_b = NetWrapper(net_b)

    # Compute: output = 0.2*A + 0.8*B
    # Expected: 0.2*10 + 0.8*20 = 2 + 16 = 18
    output = 0.2 * wrapper_a + 0.8 * wrapper_b

    # Check the result.
    result_value = output.weights.ip1_val_w.value
    expected = 18.0
    assert result_value.shape == (1,), (
        f"Expected shape (1,), got {result_value.shape}"
    )
    assert np.isclose(result_value[0], expected, rtol=1e-4), (
        f"Expected {expected}, got {result_value[0]}"
    )


def test_policy_head_replacement() -> None:
    """Test that assigning policy_heads actually replaces the data."""
    # Create network A with policy head value 10.0.
    net_a = net_pb2.Net()
    net_a.format.weights_encoding = net_pb2.Format.LINEAR16
    net_a.weights.policy_heads.ip_pol_w.min_val = 10.0
    net_a.weights.policy_heads.ip_pol_w.max_val = 10.0
    net_a.weights.policy_heads.ip_pol_w.params = np.array(
        [32767], dtype=np.uint16
    ).tobytes()

    # Create network B with policy head value 20.0.
    net_b = net_pb2.Net()
    net_b.format.weights_encoding = net_pb2.Format.LINEAR16
    net_b.weights.policy_heads.ip_pol_w.min_val = 20.0
    net_b.weights.policy_heads.ip_pol_w.max_val = 20.0
    net_b.weights.policy_heads.ip_pol_w.params = np.array(
        [32767], dtype=np.uint16
    ).tobytes()

    # Wrap and perform assignment.
    wrapper_a = NetWrapper(net_a)
    wrapper_b = NetWrapper(net_b)

    # Replace A's policy heads with B's.
    wrapper_a.weights.policy_heads = wrapper_b.weights.policy_heads

    # Verify in-memory replacement.
    result_value = wrapper_a.weights.policy_heads.ip_pol_w.value
    assert np.isclose(result_value[0], 20.0, rtol=1e-4), (
        f"Expected 20.0 (B's value), got {result_value[0]}"
    )

    # Verify persistence (round-trip).
    with tempfile.NamedTemporaryFile(suffix=".pb.gz", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        save_weights(wrapper_a, tmp_path)
        reloaded = load_weights(tmp_path)
        reloaded_value = reloaded.weights.policy_heads.ip_pol_w.value
        assert np.isclose(reloaded_value[0], 20.0, rtol=1e-4), (
            f"After save/load: Expected 20.0, got {reloaded_value[0]}"
        )
    finally:
        os.unlink(tmp_path)


def test_policy_head_map_assignment() -> None:
    """Test assigning policy_heads when one has policy_head_map and
    another doesn't."""
    # Create network A with NO policy_head_map.
    net_a = net_pb2.Net()
    net_a.format.weights_encoding = net_pb2.Format.LINEAR16
    net_a.weights.policy_heads.ip_pol_w.min_val = 5.0
    net_a.weights.policy_heads.ip_pol_w.max_val = 5.0
    net_a.weights.policy_heads.ip_pol_w.params = np.array(
        [32767], dtype=np.uint16
    ).tobytes()

    # Create network B WITH policy_head_map.
    net_b = net_pb2.Net()
    net_b.format.weights_encoding = net_pb2.Format.LINEAR16
    # Add a policy_head_map entry with required key and value fields.
    policy_map = net_b.weights.policy_heads.policy_head_map.add()
    policy_map.key = "test_policy"
    policy_map.value.ip_pol_w.min_val = 15.0
    policy_map.value.ip_pol_w.max_val = 15.0
    policy_map.value.ip_pol_w.params = np.array(
        [32767], dtype=np.uint16
    ).tobytes()

    # Wrap networks.
    wrapper_a = NetWrapper(net_a)
    wrapper_b = NetWrapper(net_b)

    # Replace A's policy heads with B's (which has policy_head_map).
    wrapper_a.weights.policy_heads = wrapper_b.weights.policy_heads

    # Access policy_head_map through the cache to verify consistency.
    assert len(wrapper_a.weights.policy_heads.policy_head_map) == 1
    assert (
        wrapper_a.weights.policy_heads.policy_head_map[0]._proto.key
        == "test_policy"
    )

    # Verify can save without serialization errors.
    with tempfile.NamedTemporaryFile(suffix=".pb.gz", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        save_weights(wrapper_a, tmp_path)
        # Verify round-trip preserves the policy_head_map.
        reloaded = load_weights(tmp_path)
        assert len(reloaded.weights.policy_heads.policy_head_map) == 1
        assert (
            reloaded.weights.policy_heads.policy_head_map[0]._proto.key
            == "test_policy"
        )
        policy_value = reloaded.weights.policy_heads.policy_head_map[
            0
        ].value.ip_pol_w.value
        assert np.isclose(policy_value[0], 15.0, rtol=1e-4)
    finally:
        os.unlink(tmp_path)


def test_noop_arithmetic() -> None:
    """Test that 0.3*A + 0.7*A equals A (no-op operation)."""
    # Create network A with simple weights (value 10.0).
    net_a = net_pb2.Net()
    net_a.format.weights_encoding = net_pb2.Format.LINEAR16
    net_a.weights.ip1_val_w.min_val = 10.0
    net_a.weights.ip1_val_w.max_val = 10.0
    net_a.weights.ip1_val_w.params = np.array(
        [32767], dtype=np.uint16
    ).tobytes()

    wrapper_a = NetWrapper(net_a)

    # Perform no-op operation: 0.3*A + 0.7*A should equal A.
    result = 0.3 * wrapper_a + 0.7 * wrapper_a

    # Verify in-memory: result should equal A (10.0).
    result_value = result.weights.ip1_val_w.value
    expected = 10.0
    assert np.isclose(result_value[0], expected, rtol=1e-4), (
        f"Expected {expected}, got {result_value[0]}"
    )

    # Verify persistence: save and reload.
    with tempfile.NamedTemporaryFile(suffix=".pb.gz", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        save_weights(result, tmp_path)
        reloaded = load_weights(tmp_path)
        reloaded_value = reloaded.weights.ip1_val_w.value
        assert np.isclose(reloaded_value[0], expected, rtol=1e-4), (
            f"After save/load: Expected {expected}, got {reloaded_value[0]}"
        )
    finally:
        os.unlink(tmp_path)
net_a = net_pb2.Net() net_a.format.weights_encoding = net_pb2.Format.LINEAR16 encoder_a = net_a.weights.encoder.add() encoder_a.mha.q_w.min_val = 10.0 encoder_a.mha.q_w.max_val = 10.0 encoder_a.mha.q_w.params = np.array([32767], dtype=np.uint16).tobytes() # Create network B with encoder block containing value 20.0. net_b = net_pb2.Net() net_b.format.weights_encoding = net_pb2.Format.LINEAR16 encoder_b = net_b.weights.encoder.add() encoder_b.mha.q_w.min_val = 20.0 encoder_b.mha.q_w.max_val = 20.0 encoder_b.mha.q_w.params = np.array([32767], dtype=np.uint16).tobytes() # Wrap networks. wrapper_a = NetWrapper(net_a) wrapper_b = NetWrapper(net_b) # Assign encoder[0] from B to A. wrapper_a.weights.encoder[0] = wrapper_b.weights.encoder[0] # Verify in-memory: A's encoder[0] should now have value 20.0. result_value = wrapper_a.weights.encoder[0].mha.q_w.value assert np.isclose(result_value[0], 20.0, rtol=1e-4), ( f"Expected 20.0 (B's value), got {result_value[0]}" ) # Verify persistence. with tempfile.NamedTemporaryFile(suffix=".pb.gz", delete=False) as tmp: tmp_path = tmp.name try: save_weights(wrapper_a, tmp_path) reloaded = load_weights(tmp_path) reloaded_value = reloaded.weights.encoder[0].mha.q_w.value assert np.isclose(reloaded_value[0], 20.0, rtol=1e-4), ( f"After save/load: Expected 20.0, got {reloaded_value[0]}" ) finally: os.unlink(tmp_path) ================================================ FILE: src/lczero_training/tools/__init__.py ================================================ """Pure Python tools for weight manipulation.""" from .weights_tool import load_weights, save_weights __all__ = ["load_weights", "save_weights"] ================================================ FILE: src/lczero_training/tools/weight_codecs.py ================================================ """Encoding and decoding logic for Lc0 weight formats.""" import numpy as np from proto import net_pb2 def decode_linear16( params: bytes, min_val: float, max_val: float, shape: tuple[int, ...] 
def decode_linear16(
    params: bytes, min_val: float, max_val: float, shape: tuple[int, ...]
) -> np.ndarray:
    """Decode LINEAR16 format to float32 array.

    Each uint16 is treated as a fraction of the [min_val, max_val] range.
    """
    quantized = np.frombuffer(params, dtype=np.uint16)
    fraction = quantized.astype(np.float32) / 65535.0
    # Linear interpolation between min_val (fraction 0) and max_val
    # (fraction 1).
    decoded = fraction * max_val + (1.0 - fraction) * min_val
    # Stored layout is column-major relative to `shape`, so build the
    # reversed shape and transpose back.
    return decoded.reshape(shape[::-1]).transpose()


def encode_linear16(arr: np.ndarray) -> tuple[bytes, float, float]:
    """Encode float32 array to LINEAR16 format.

    Returns (params_bytes, min_val, max_val).
    """
    flat = arr.T.flatten().astype(np.float32)
    lo = float(flat.min())
    hi = float(flat.max())
    span = hi - lo
    if span < 1e-8:
        # Constant tensor: store mid-scale so decoding yields
        # (lo + hi) / 2, i.e. the constant value itself.
        fraction = np.full_like(flat, 0.5)
    else:
        fraction = (flat - lo) / span
    quantized = np.round(fraction * 65535.0).astype(np.uint16)
    return quantized.tobytes(), lo, hi


def decode_float16(params: bytes, shape: tuple[int, ...]) -> np.ndarray:
    """Decode FLOAT16 format to float32 array."""
    halves = np.frombuffer(params, dtype=np.float16)
    widened = halves.astype(np.float32)
    return widened.reshape(shape[::-1]).transpose()


def encode_float16(arr: np.ndarray) -> tuple[bytes, float, float]:
    """Encode float32 array to FLOAT16 format.

    min/max are unused for this encoding and reported as 0.0.
    """
    halves = arr.T.flatten().astype(np.float16)
    return halves.tobytes(), 0.0, 0.0


def decode_bfloat16(params: bytes, shape: tuple[int, ...]) -> np.ndarray:
    """Decode BFLOAT16 format to float32 array via bit manipulation."""
    # bfloat16 is the top 16 bits of an IEEE-754 float32; shift back up
    # and reinterpret the bits.
    high_bits = np.frombuffer(params, dtype=np.uint16)
    widened = high_bits.astype(np.uint32) << 16
    decoded = widened.view(np.float32)
    return decoded.reshape(shape[::-1]).transpose()


def encode_bfloat16(arr: np.ndarray) -> tuple[bytes, float, float]:
    """Encode float32 array to BFLOAT16 format via bit manipulation.

    Mantissa bits are truncated (not rounded). min/max are reported as 0.0.
    """
    flat = arr.T.flatten().astype(np.float32)
    as_bits = flat.view(np.uint32)
    truncated = (as_bits >> 16).astype(np.uint16)
    return truncated.tobytes(), 0.0, 0.0


def decode_layer(
    layer: net_pb2.Weights.Layer, fallback_encoding: int
) -> np.ndarray:
    """Decode a Layer protobuf to float32 NumPy array.

    Uses layer.encoding when set, otherwise fallback_encoding.
    """
    encoding = layer.encoding if layer.encoding else fallback_encoding
    shape: tuple[int, ...]
    if layer.dims:
        shape = tuple(layer.dims)
    else:
        # No dims recorded: infer a flat shape. All three encodings store
        # two bytes per element.
        shape = (len(layer.params) // 2,)
    if encoding == net_pb2.Weights.Layer.LINEAR16:
        return decode_linear16(
            layer.params, layer.min_val, layer.max_val, shape
        )
    if encoding == net_pb2.Weights.Layer.FLOAT16:
        return decode_float16(layer.params, shape)
    if encoding == net_pb2.Weights.Layer.BFLOAT16:
        return decode_bfloat16(layer.params, shape)
    raise ValueError(f"Unknown encoding: {encoding}")


def encode_layer(arr: np.ndarray, encoding: int) -> tuple[bytes, float, float]:
    """Encode a float32 NumPy array to Layer format.

    Returns (params_bytes, min_val, max_val).
    """
    if encoding == net_pb2.Weights.Layer.LINEAR16:
        return encode_linear16(arr)
    if encoding == net_pb2.Weights.Layer.FLOAT16:
        return encode_float16(arr)
    if encoding == net_pb2.Weights.Layer.BFLOAT16:
        return encode_bfloat16(arr)
    raise ValueError(f"Unknown encoding: {encoding}")
# ===== FILE: src/lczero_training/tools/weight_wrappers.py =====
"""Wrapper classes for pythonic access to Lc0 weight protobufs."""

import gzip
from typing import Any, Iterator

import numpy as np
from google.protobuf.message import Message

from proto import net_pb2

from . import weight_codecs


class LayerWrapper:
    """Wraps a net_pb2.Weights.Layer with lazy float32 decoding."""

    __slots__ = ("_proto", "_fallback_encoding", "_cached_array", "_modified")

    def __init__(
        self, proto: net_pb2.Weights.Layer, fallback_encoding: int
    ) -> None:
        # object.__setattr__ is used throughout; with __slots__ this keeps
        # assignment uniform with the guarded mutation in commit()/value.
        object.__setattr__(self, "_proto", proto)
        object.__setattr__(self, "_fallback_encoding", fallback_encoding)
        object.__setattr__(self, "_cached_array", None)
        object.__setattr__(self, "_modified", False)

    # Slot type declarations for static checkers.
    _proto: net_pb2.Weights.Layer
    _fallback_encoding: int
    _cached_array: np.ndarray | None
    _modified: bool

    @property
    def value(self) -> np.ndarray:
        """Decode to float32 on first access."""
        # Lazy decode: the proto bytes are only unpacked the first time the
        # array is requested, then memoized in _cached_array.
        if self._cached_array is None:
            decoded = weight_codecs.decode_layer(
                self._proto, self._fallback_encoding
            )
            object.__setattr__(self, "_cached_array", decoded)
        assert self._cached_array is not None
        return self._cached_array

    @value.setter
    def value(self, arr: np.ndarray) -> None:
        """Set new array value, mark as modified."""
        # The proto is NOT updated here; re-encoding happens in commit().
        object.__setattr__(self, "_cached_array", arr.astype(np.float32))
        object.__setattr__(self, "_modified", True)

    def commit(self, encoding: int) -> None:
        """Re-encode array to proto if modified."""
        if self._modified and self._cached_array is not None:
            params, min_val, max_val = weight_codecs.encode_layer(
                self._cached_array, encoding
            )
            self._proto.params = params
            self._proto.min_val = min_val
            self._proto.max_val = max_val
            self._proto.encoding = encoding  # type: ignore[assignment]
            # Replace the recorded dims with the cached array's shape.
            del self._proto.dims[:]
            self._proto.dims.extend(self._cached_array.shape)
            object.__setattr__(self, "_modified", False)

    def __add__(self, other: "LayerWrapper") -> "LayerWrapper":
        # Arithmetic produces a detached wrapper around a fresh Layer proto;
        # the proto stays empty until commit() is called on the result.
        if not isinstance(other, LayerWrapper):
            raise TypeError("Can only add LayerWrapper to LayerWrapper")
        result = LayerWrapper(net_pb2.Weights.Layer(), self._fallback_encoding)
        result.value = self.value + other.value
        return result

    def __sub__(self, other: "LayerWrapper") -> "LayerWrapper":
        if not isinstance(other, LayerWrapper):
            raise TypeError("Can only subtract LayerWrapper from LayerWrapper")
        result = LayerWrapper(net_pb2.Weights.Layer(), self._fallback_encoding)
        result.value = self.value - other.value
        return result

    def __mul__(self, scalar: float) -> "LayerWrapper":
        if not isinstance(scalar, (int, float)):
            raise TypeError("Can only multiply LayerWrapper by scalar")
        result = LayerWrapper(net_pb2.Weights.Layer(), self._fallback_encoding)
        result.value = self.value * scalar
        return result

    def __rmul__(self, scalar: float) -> "LayerWrapper":
        return self.__mul__(scalar)
class ListWrapper:
    """Wraps protobuf repeated fields."""

    __slots__ = ("_proto_list", "_parent", "_item_cache")

    def __init__(self, proto_list: Any, parent: "NetWrapper") -> None:
        self._proto_list = proto_list
        self._parent = parent
        # Cache of wrapped items, keyed by index, so repeated access
        # returns the same wrapper object (important for commit tracking).
        self._item_cache: dict[int, Any] = {}

    _proto_list: Any
    _parent: "NetWrapper"
    _item_cache: dict[int, Any]

    def __len__(self) -> int:
        return len(self._proto_list)

    def __getitem__(self, idx: int) -> Any:
        if idx not in self._item_cache:
            item_proto = self._proto_list[idx]
            self._item_cache[idx] = self._parent._wrap_field(item_proto)
        return self._item_cache[idx]

    def __setitem__(self, idx: int, value: Any) -> None:
        """Support item assignment for list elements."""
        if isinstance(value, NetWrapper):
            # Copy proto contents in place (repeated message elements
            # cannot be rebound), then rewrap the destination proto.
            dest_proto = self._proto_list[idx]
            dest_proto.CopyFrom(value._proto)
            # Create new wrapper for destination proto to maintain cache
            # consistency.
            self._item_cache[idx] = NetWrapper(
                dest_proto, self._parent._fallback_encoding
            )
        elif isinstance(value, LayerWrapper):
            dest_proto = self._proto_list[idx]
            dest_proto.CopyFrom(value._proto)
            # Create new wrapper for destination proto to maintain cache
            # consistency.
            self._item_cache[idx] = LayerWrapper(
                dest_proto, self._parent._fallback_encoding
            )
        else:
            # Plain (scalar) assignment: drop any stale cached wrapper.
            self._proto_list[idx] = value
            if idx in self._item_cache:
                del self._item_cache[idx]

    def __iter__(self) -> Iterator[Any]:
        for i in range(len(self)):
            yield self[i]


class NetWrapper:
    """Wraps net_pb2.Net or nested Message types."""

    __slots__ = ("_proto", "_fallback_encoding", "_attr_cache")

    def __init__(
        self, proto_msg: Message, fallback_encoding: int | None = None
    ) -> None:
        object.__setattr__(self, "_proto", proto_msg)
        # NOTE(review): `or` also falls through when fallback_encoding is
        # 0; assumes 0 is not a valid encoding value — confirm against
        # net.proto.
        object.__setattr__(
            self,
            "_fallback_encoding",
            fallback_encoding or self._detect_encoding(),
        )
        object.__setattr__(self, "_attr_cache", {})

    _proto: Message
    _fallback_encoding: int
    _attr_cache: dict[str, Any]

    def _detect_encoding(self) -> int:
        """Extract encoding from net.format.weights_encoding."""
        if hasattr(self._proto, "format") and self._proto.HasField("format"):
            return self._proto.format.weights_encoding
        return net_pb2.Weights.Layer.LINEAR16

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails, i.e. for proto
        # field names. Underscore names fall back to the slots directly.
        if name.startswith("_"):
            return object.__getattribute__(self, name)
        if name in self._attr_cache:
            return self._attr_cache[name]
        if not hasattr(self._proto, name):
            raise AttributeError(
                f"{type(self._proto).__name__} has no field '{name}'"
            )
        value = getattr(self._proto, name)
        wrapped = self._wrap_field(value)
        self._attr_cache[name] = wrapped
        return wrapped

    def _wrap_field(self, value: Any) -> Any:
        """Determine wrapper type based on proto field."""
        if isinstance(value, net_pb2.Weights.Layer):
            return LayerWrapper(value, self._fallback_encoding)
        elif isinstance(value, Message):
            return NetWrapper(value, self._fallback_encoding)
        elif hasattr(value, "__len__") and not isinstance(value, (str, bytes)):
            # Repeated field containers have __len__; exclude plain
            # str/bytes scalars.
            return ListWrapper(value, self)
        else:
            return value

    def __setattr__(self, name: str, value: Any) -> None:
        if name.startswith("_"):
            object.__setattr__(self, name, value)
            return
        if isinstance(value, NetWrapper):
            # Message fields cannot be rebound on a proto; copy contents
            # into the destination submessage instead.
            dest_proto = getattr(self._proto, name)
            dest_proto.CopyFrom(value._proto)
            # Create new wrapper for destination proto to maintain cache
            # consistency.
            self._attr_cache[name] = NetWrapper(
                dest_proto, self._fallback_encoding
            )
        elif isinstance(value, LayerWrapper):
            dest_proto = getattr(self._proto, name)
            dest_proto.CopyFrom(value._proto)
            # Create new wrapper for destination proto to maintain cache
            # consistency.
            self._attr_cache[name] = LayerWrapper(
                dest_proto, self._fallback_encoding
            )
        else:
            setattr(self._proto, name, value)

    def save(self, path: str, encoding: int | None = None) -> None:
        """Save to file, committing all modified layers."""
        if encoding is None:
            encoding = net_pb2.Weights.Layer.FLOAT16
        self._commit_all(encoding)
        serialized = self._proto.SerializePartialToString()
        if path.endswith(".gz"):
            with gzip.open(path, "wb") as f:
                f.write(serialized)
        else:
            with open(path, "wb") as f:
                f.write(serialized)

    def _commit_all(self, encoding: int) -> None:
        """Recursively commit all modified LayerWrappers."""
        # Only wrappers that were ever accessed live in _attr_cache, and
        # only accessed layers can be modified, so walking the cache is
        # sufficient.
        for cached_value in self._attr_cache.values():
            if isinstance(cached_value, LayerWrapper):
                cached_value.commit(encoding)
            elif isinstance(cached_value, NetWrapper):
                cached_value._commit_all(encoding)
            elif isinstance(cached_value, ListWrapper):
                for item in cached_value:
                    if isinstance(item, NetWrapper):
                        item._commit_all(encoding)
                    elif isinstance(item, LayerWrapper):
                        item.commit(encoding)

    def __add__(self, other: "NetWrapper") -> "NetWrapper":
        """Element-wise addition of two networks."""
        if not isinstance(other, NetWrapper):
            raise TypeError("Can only add NetWrapper to NetWrapper")
        # Start from a copy of self so non-weight metadata is preserved.
        result_proto = type(self._proto)()
        result_proto.CopyFrom(self._proto)
        result = NetWrapper(result_proto, self._fallback_encoding)
        result._add_weights(self, other)
        return result

    def __sub__(self, other: "NetWrapper") -> "NetWrapper":
        """Element-wise subtraction of two networks."""
        if not isinstance(other, NetWrapper):
            raise TypeError("Can only subtract NetWrapper from NetWrapper")
        result_proto = type(self._proto)()
        result_proto.CopyFrom(self._proto)
        result = NetWrapper(result_proto, self._fallback_encoding)
        result._sub_weights(self, other)
        return result

    def __mul__(self, scalar: float) -> "NetWrapper":
        """Scalar multiplication."""
        if not isinstance(scalar, (int, float)):
            raise TypeError("Can only multiply NetWrapper by scalar")
        result_proto = type(self._proto)()
        result_proto.CopyFrom(self._proto)
        result = NetWrapper(result_proto, self._fallback_encoding)
        result._mul_weights(self, scalar)
        return result

    def __rmul__(self, scalar: float) -> "NetWrapper":
        return self.__mul__(scalar)

    def _add_weights(self, lhs: "NetWrapper", rhs: "NetWrapper") -> None:
        """Recursively add weights from lhs and rhs into self."""
        for field_desc in lhs._proto.DESCRIPTOR.fields:
            field_name = field_desc.name
            if not hasattr(lhs._proto, field_name):
                continue
            # Skip optional fields that are not set in both inputs.
            if not field_desc.is_required and not field_desc.is_repeated:
                lhs_has = lhs._proto.HasField(field_name)
                rhs_has = rhs._proto.HasField(field_name)
                if not lhs_has or not rhs_has:
                    continue
            lhs_val = getattr(lhs, field_name)
            rhs_val = getattr(rhs, field_name)
            if isinstance(lhs_val, LayerWrapper):
                self_layer = getattr(self, field_name)
                self_layer.value = lhs_val.value + rhs_val.value
            elif isinstance(lhs_val, NetWrapper):
                self_wrapper = getattr(self, field_name)
                self_wrapper._add_weights(lhs_val, rhs_val)
            elif isinstance(lhs_val, ListWrapper):
                self_list = getattr(self, field_name)
                # Only process indices that exist in both lists.
                min_len = min(len(lhs_val), len(rhs_val))
                for i in range(min_len):
                    if isinstance(lhs_val[i], (NetWrapper, LayerWrapper)):
                        if isinstance(lhs_val[i], LayerWrapper):
                            self_list[i].value = (
                                lhs_val[i].value + rhs_val[i].value
                            )
                        else:
                            self_list[i]._add_weights(lhs_val[i], rhs_val[i])
                # Truncate to min_len to avoid incomplete entries.
                del self_list._proto_list[min_len:]

    def _sub_weights(self, lhs: "NetWrapper", rhs: "NetWrapper") -> None:
        """Recursively subtract weights rhs from lhs into self."""
        for field_desc in lhs._proto.DESCRIPTOR.fields:
            field_name = field_desc.name
            if not hasattr(lhs._proto, field_name):
                continue
            # Skip optional fields that are not set in both inputs.
            if not field_desc.is_required and not field_desc.is_repeated:
                lhs_has = lhs._proto.HasField(field_name)
                rhs_has = rhs._proto.HasField(field_name)
                if not lhs_has or not rhs_has:
                    continue
            lhs_val = getattr(lhs, field_name)
            rhs_val = getattr(rhs, field_name)
            if isinstance(lhs_val, LayerWrapper):
                self_layer = getattr(self, field_name)
                self_layer.value = lhs_val.value - rhs_val.value
            elif isinstance(lhs_val, NetWrapper):
                self_wrapper = getattr(self, field_name)
                self_wrapper._sub_weights(lhs_val, rhs_val)
            elif isinstance(lhs_val, ListWrapper):
                self_list = getattr(self, field_name)
                # Only process indices that exist in both lists.
                min_len = min(len(lhs_val), len(rhs_val))
                for i in range(min_len):
                    if isinstance(lhs_val[i], (NetWrapper, LayerWrapper)):
                        if isinstance(lhs_val[i], LayerWrapper):
                            self_list[i].value = (
                                lhs_val[i].value - rhs_val[i].value
                            )
                        else:
                            self_list[i]._sub_weights(lhs_val[i], rhs_val[i])
                # Truncate to min_len to avoid incomplete entries.
                del self_list._proto_list[min_len:]

    def _mul_weights(self, source: "NetWrapper", scalar: float) -> None:
        """Recursively multiply all weights by scalar."""
        for field_desc in source._proto.DESCRIPTOR.fields:
            field_name = field_desc.name
            if not hasattr(source._proto, field_name):
                continue
            # Skip unset optional fields to avoid creating them in output.
            if (
                not field_desc.is_required
                and not field_desc.is_repeated
                and not source._proto.HasField(field_name)
            ):
                continue
            src_val = getattr(source, field_name)
            if isinstance(src_val, LayerWrapper):
                self_layer = getattr(self, field_name)
                self_layer.value = src_val.value * scalar
            elif isinstance(src_val, NetWrapper):
                self_wrapper = getattr(self, field_name)
                self_wrapper._mul_weights(src_val, scalar)
            elif isinstance(src_val, ListWrapper):
                self_list = getattr(self, field_name)
                for i in range(len(src_val)):
                    if isinstance(src_val[i], (NetWrapper, LayerWrapper)):
                        if isinstance(src_val[i], LayerWrapper):
                            self_list[i].value = src_val[i].value * scalar
                        else:
                            self_list[i]._mul_weights(src_val[i], scalar)


# ===== FILE: src/lczero_training/tools/weights_tool.py =====
"""Main API for loading and saving Lc0 weight files."""

import gzip

from proto import net_pb2

from .weight_wrappers import NetWrapper


def load_weights(path: str) -> NetWrapper:
    """Load Lc0 weights file (.pb or .pb.gz)."""
    if path.endswith(".gz"):
        with gzip.open(path, "rb") as f:
            contents = f.read()
    else:
        with open(path, "rb") as f:
            contents = f.read()
    net = net_pb2.Net()
    net.ParseFromString(contents)
    return NetWrapper(net)


def save_weights(
    wrapper: NetWrapper, path: str, encoding: str = "FLOAT16"
) -> None:
    """Save weights to file.

    encoding is a case-insensitive name; unknown names raise KeyError.
    """
    encoding_map = {
        "LINEAR16": net_pb2.Weights.Layer.LINEAR16,
        "FLOAT16": net_pb2.Weights.Layer.FLOAT16,
        "BFLOAT16": net_pb2.Weights.Layer.BFLOAT16,
    }
    enc_value = encoding_map[encoding.upper()]
    wrapper.save(path, enc_value)


# ===== FILE: src/lczero_training/training/__init__.py =====
"""Training package for Leela Chess Zero."""
# ===== FILE: src/lczero_training/training/backfill_metrics.py =====
"""Backfill metrics for existing checkpoints."""

import logging
from typing import Any

from flax import nnx
from google.protobuf import text_format

from lczero_training.daemon.metrics import (
    evaluate_batch,
    load_batch_from_npz,
)
from lczero_training.model.loss_function import LczeroLoss
from lczero_training.model.model import LczeroModel
from lczero_training.training.migrate_checkpoint import (
    Migration,
    get_checkpoint_steps,
    load_checkpoint,
    load_migration_rules,
)
from lczero_training.training.state import TrainingState
from lczero_training.training.tensorboard import TensorboardLogger
from proto.root_config_pb2 import RootConfig

logger = logging.getLogger(__name__)


def _load_config(config_path: str) -> RootConfig:
    """Load RootConfig from textproto file."""
    config = RootConfig()
    with open(config_path, "r") as f:
        text_format.Parse(f.read(), config)
    return config


def _validate_and_get_metrics(
    root_config: RootConfig, metric_names: list[str]
) -> dict[str, Any]:
    """Validate metrics exist and are NPZ type.

    Returns a name -> metric-config mapping for the requested NPZ metrics;
    raises ValueError for non-NPZ or unknown metric names.
    """
    if not root_config.HasField("metrics"):
        raise ValueError("No metrics configuration found in root config")
    metric_configs = {
        mc.name: mc
        for mc in root_config.metrics.metric
        if mc.name in metric_names and mc.HasField("npz_filename")
    }
    non_npz = [
        mc.name
        for mc in root_config.metrics.metric
        if mc.name in metric_names and not mc.HasField("npz_filename")
    ]
    if non_npz:
        raise ValueError(f"Non-NPZ metrics: {', '.join(non_npz)}")
    missing = set(metric_names) - set(metric_configs.keys())
    if missing:
        raise ValueError(f"Metrics not found: {', '.join(sorted(missing))}")
    return metric_configs


def _load_and_migrate_checkpoint(
    checkpoint_path: str,
    step: int,
    template: TrainingState,
    rules: list[tuple],
) -> TrainingState:
    """Load checkpoint and apply migration if rules provided."""
    # NOTE(review): Migration.run is invoked unconditionally, even when
    # `rules` is empty — presumably a no-op in that case; verify.
    state, _ = load_checkpoint(checkpoint_path, step)
    return Migration(state, template).run(rules)


def backfill_metrics(
    config_path: str,
    metric_names: list[str],
    min_step: int | None = None,
    max_step: int | None = None,
    migration_config_path: str | None = None,
) -> None:
    """Backfill metrics for existing checkpoints.

    Args:
        config_path: Path to the RootConfig textproto file.
        metric_names: Names of metrics to backfill (must be NPZ metrics).
        min_step: Minimum checkpoint step (inclusive), or None.
        max_step: Maximum checkpoint step (inclusive), or None.
        migration_config_path: Path to CheckpointMigrationConfig file, or
            None.

    Raises:
        ValueError: If any metric is not an NPZ metric or doesn't exist.
    """
    config = _load_config(config_path)
    metric_configs = _validate_and_get_metrics(config, metric_names)

    # Load batches and create loggers.
    batches = {
        name: load_batch_from_npz(mc.npz_filename)
        for name, mc in metric_configs.items()
    }
    loggers = {
        name: TensorboardLogger(f"{config.metrics.tensorboard_path}/{name}")
        for name in metric_configs
    }

    # Initialize model components.
    loss_fn = LczeroLoss(config=config.training.losses)
    graphdef, _ = nnx.split(
        LczeroModel(config=config.model, rngs=nnx.Rngs(params=42))
    )

    # Prepare migration if needed.
    rules = load_migration_rules(migration_config_path)
    template = TrainingState.new_from_config(config.model, config.training)

    # Get and process checkpoints.
    steps = get_checkpoint_steps(
        config.training.checkpoint.path, min_step, max_step
    )
    logger.info(f"Processing {len(steps)} checkpoints")
    if not steps:
        logger.warning("No checkpoints found in range")
        return

    try:
        for step in steps:
            logger.info(f"Step {step}")
            state = _load_and_migrate_checkpoint(
                config.training.checkpoint.path, step, template, rules
            )
            for name, mc in metric_configs.items():
                metrics = evaluate_batch(
                    batches[name],
                    state.jit_state,
                    graphdef,
                    loss_fn,
                    mc.use_swa_model,
                )
                loggers[name].log(step, metrics)
                logger.info(f" {name}: loss={metrics['loss']:.6f}")
    finally:
        # Always flush/close tensorboard writers, even on failure.
        for tb_logger in loggers.values():
            tb_logger.close()
    logger.info("Backfill complete")


# ===== FILE: src/lczero_training/training/dataloader_probe.py =====
"""Utilities for exercising the training data loader."""

import logging
import time
from contextlib import suppress
from pathlib import Path
from typing import Optional

import numpy as np
from google.protobuf import text_format

from lczero_training.dataloader import DataLoader, make_dataloader
from proto.root_config_pb2 import RootConfig

logger = logging.getLogger(__name__)


def _stop_loader(loader: DataLoader) -> None:
    # Best-effort shutdown; swallow any error raised while stopping.
    with suppress(Exception):
        loader.stop()


def _store_batches(path: str, batches: list) -> None:
    output = Path(path)
    if output.parent:
        output.parent.mkdir(parents=True, exist_ok=True)
    logger.info("Writing %d batches to %s", len(batches), output)
    # Use an object array so heterogeneous batch structures survive savez.
    container = np.empty(len(batches), dtype=object)
    container[:] = batches
    np.savez(output, batches=container)
""" if num_batches < 1: raise ValueError("num_batches must be at least 1") config = RootConfig() logger.info("Reading configuration from proto file") with open(config_filename, "r") as config_file: text_format.Parse(config_file.read(), config) logger.info("Creating data loader") loader = make_dataloader(config.data_loader) collected_batches: list = [] collect_enabled = npz_output is not None first_batch_time = 0.0 remaining_batches = num_batches - 1 try: logger.info("Fetching first batch") start_time = time.perf_counter() first_batch = loader.get_next() if collect_enabled: collected_batches.append(first_batch) first_batch_time = time.perf_counter() - start_time logger.info("Time to first batch: %.3f seconds", first_batch_time) if remaining_batches <= 0: logger.info("Only fetched first batch; skipping throughput") return logger.info( "Fetching %d additional batches for throughput measurement", remaining_batches, ) throughput_start = time.perf_counter() for _ in range(remaining_batches): batch = loader.get_next() if collect_enabled: collected_batches.append(batch) throughput_duration = time.perf_counter() - throughput_start if throughput_duration <= 0: logger.warning("Measured non-positive duration; skipping rate") return batches_per_second = remaining_batches / throughput_duration logger.info( "Throughput excluding first batch: %.2f batches/second", batches_per_second, ) logger.info( "Total time excluding first batch: %.3f seconds", throughput_duration, ) finally: if collect_enabled and npz_output is not None and collected_batches: _store_batches(npz_output, collected_batches) _stop_loader(loader) ================================================ FILE: src/lczero_training/training/describe.py ================================================ import logging import sys from pathlib import PurePosixPath import jax import jax.numpy as jnp import orbax.checkpoint as ocp from flax import nnx from google.protobuf import text_format from lczero_training.training.state import 
TrainingState from proto.root_config_pb2 import RootConfig logger = logging.getLogger(__name__) def describe( config_filename: str, shapes: bool = False, values: bool = False, weight_paths: bool = False, ) -> None: config = RootConfig() logger.info("Reading configuration from proto file") with open(config_filename, "r") as f: text_format.Parse(f.read(), config) if config.training.checkpoint.path is None: logger.error("Checkpoint path must be set in the configuration.") sys.exit(1) checkpoint_mgr = ocp.CheckpointManager( config.training.checkpoint.path, options=ocp.CheckpointManagerOptions( create=True, ), ) logger.info("Creating state from configuration") empty_state = TrainingState.new_from_config( model_config=config.model, training_config=config.training, ) logger.info("Restoring checkpoint") training_state = checkpoint_mgr.restore( None, args=ocp.args.PyTreeRestore(empty_state) ) logger.info("Restored checkpoint") assert isinstance(training_state, TrainingState) if values: logger.info("Dumping training state values") print("Training state:") print(training_state) if shapes: logger.info("Extracting training state shapes") shapes = jax.tree.map(jnp.shape, training_state) print("Training state shapes:") print(shapes) if weight_paths: paths = [] def _collect(path: tuple[object, ...], _: nnx.Variable) -> bool: paths.append(str(PurePosixPath(*map(str, path)))) return False nnx.map_state(_collect, training_state.jit_state.model_state) for p in sorted( paths, key=lambda p: tuple( int(c) if c.isdigit() else c for c in p.split("/") ), ): print(p) ================================================ FILE: src/lczero_training/training/eval.py ================================================ # Description: Evaluation script for comparing model outputs and calculating losses. # # This script provides functionalities to evaluate a trained model by processing # data samples, calculating losses, and comparing outputs against an ONNX model. 
# It supports dumping tensors and results to various formats for analysis.
import json
import logging
import math
import shelve
import sys
from dataclasses import dataclass
from datetime import datetime
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Optional,
    Sequence,
    TextIO,
    Tuple,
    cast,
)

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx
from google.protobuf import text_format

from lczero_training.dataloader import (
    DataLoader,
    make_dataloader,
)
from lczero_training.model.loss_function import LczeroLoss
from lczero_training.model.model import LczeroModel, ModelPrediction
from lczero_training.training.state import TrainingSample, TrainingState
from proto import data_loader_config_pb2
from proto.root_config_pb2 import RootConfig

logger = logging.getLogger(__name__)

# Relative diffs are only computed where |jax| >= this, to avoid dividing
# by (near-)zero reference values.
RELATIVE_DIFFERENCE_EPSILON = 1e-4
# Legacy head names used for JAX-vs-ONNX comparison, in alignment order.
HEADS = ("wdl", "policy", "movesleft")


@dataclass
class DiffRecord:
    """Stores information about the difference between JAX and ONNX outputs."""

    batch: int
    sample: int
    index: Tuple[int, ...]
    diff: float
    jax_value: float
    onnx_value: float


def _tensor_to_list(obj: Any) -> Any:
    """Recursively converts JAX/Numpy arrays to Python lists for serialization."""
    # Duck-typed: anything with .tolist() (np.ndarray, jax.Array) is converted.
    if hasattr(obj, "tolist"):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: _tensor_to_list(value) for key, value in obj.items()}
    return obj


# --- Diff statistics helpers ---
def _bin_counts(values: np.ndarray) -> Dict[str, Any]:
    """Counts values into bins based on powers of 2.

    Returns {"zero": <count of exact zeros>, "bins": {exponent: count}}
    where a value v falls into bin floor(log2(v)).
    """
    flat = np.asarray(values).ravel()
    zero_count = int(np.count_nonzero(flat == 0.0))
    non_zero = flat[flat != 0.0]
    bins: Dict[int, int] = {}
    if non_zero.size > 0:
        exponents = np.floor(np.log2(non_zero)).astype(int)
        unique, counts = np.unique(exponents, return_counts=True)
        bins = {int(exp): int(count) for exp, count in zip(unique, counts)}
    return {"zero": zero_count, "bins": bins}


def _format_bound(value: float) -> str:
    """Formats a float for display in statistics."""
    # Integral values >= 1 print without a decimal point.
    if value >= 1 and math.isclose(value, round(value)):
        return str(int(round(value)))
    return f"{value:.6g}"


def _format_stats(stats: Dict[str, Any]) -> str:
    """Formats bin statistics into a readable string, largest bins first."""
    lines = [f" zero={stats['zero']}"]
    for exponent in sorted(stats["bins"].keys(), reverse=True):
        lower = 2**exponent
        upper = 2 ** (exponent + 1)
        lines.append(
            f" [{_format_bound(lower)}; {_format_bound(upper)})="
            f"{stats['bins'][exponent]}"
        )
    return "\n".join(lines)


def _collect_diff_statistics(
    jax_output: np.ndarray, onnx_output: np.ndarray
) -> Tuple[np.ndarray, Dict[str, Any], Optional[Dict[str, Any]]]:
    """Collects absolute and relative difference statistics.

    Returns (abs_diff array, abs-diff bins, rel-diff bins or None when every
    |jax| value is below RELATIVE_DIFFERENCE_EPSILON).
    """
    abs_diff = np.abs(jax_output - onnx_output)
    abs_stats = _bin_counts(abs_diff)
    mask = np.abs(jax_output) >= RELATIVE_DIFFERENCE_EPSILON
    if not np.any(mask):
        return abs_diff, abs_stats, None
    rel_diff = np.zeros_like(abs_diff, dtype=np.float64)
    # `where=mask` leaves masked-out entries at 0 instead of dividing by ~0.
    np.divide(abs_diff, np.abs(jax_output), out=rel_diff, where=mask)
    rel_stats = _bin_counts(rel_diff[mask])
    return abs_diff, abs_stats, rel_stats


class Dumper:
    """Handles dumping of evaluation artifacts."""

    def __init__(
        self,
        to_stdout: bool,
        to_file: Optional[str],
        to_shelve: Optional[str],
        to_json: Optional[str],
    ):
        self.to_stdout = to_stdout
        self.shelve_path = to_shelve
        self.json_path = to_json
        # Opened eagerly; released by close().
        self.file_handle: Optional[TextIO] = (
            open(to_file, "w") if to_file else None
        )

    def dump_tensors(self, tensors: dict, prefix: str) -> None:
        """Dumps tensors to stdout or a text file."""
        if not self.to_stdout and not self.file_handle:
            return
        lines = [f"# === {prefix} TENSORS ==="]
        for name, tensor in tensors.items():
            lines.append(f"{name} = {str(_tensor_to_list(tensor))}")
        lines.append("")
        output_text = "\n".join(lines)
        if self.to_stdout:
            print(output_text)
        if self.file_handle:
            self.file_handle.write(output_text)
            self.file_handle.flush()

    def dump_structured(self, batch: dict, outputs: dict, losses: dict) -> None:
        """Dumps results to structured formats like JSON or shelve."""
        if not self.shelve_path and not self.json_path:
            return
        all_data = {
            **_tensor_to_list(batch),
            **_tensor_to_list(outputs),
            **_tensor_to_list(losses),
        }
        # Timestamp key has 1-second resolution; repeated dumps within the
        # same second share a key and the later one wins.
        key = f"sample-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
        if self.shelve_path:
            self._dump_to_shelve(key, all_data)
        if self.json_path:
            self._dump_to_json(key, all_data)

    def _dump_to_shelve(self, key: str, data: dict) -> None:
        assert self.shelve_path is not None
        with shelve.open(self.shelve_path) as db:
            db[key] = data
        logger.info("Dumped data to shelve with key: %s", key)

    def _dump_to_json(self, key: str, data: dict) -> None:
        assert self.json_path is not None
        # Read-modify-write: preserve entries from previous runs.
        try:
            with open(self.json_path, "r") as f:
                json_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            json_data = {}
        json_data[key] = data
        with open(self.json_path, "w") as f:
            json.dump(json_data, f, indent=2)
        logger.info("Dumped data to JSON with key: %s", key)

    def close(self) -> None:
        if self.file_handle:
            self.file_handle.close()


class OnnxComparator:
    """Handles comparison of JAX model outputs with ONNX model outputs."""

    def
__init__(self, onnx_model_path: str):
        # onnxruntime is imported lazily so the rest of eval.py works
        # without it when no ONNX comparison is requested.
        try:
            import onnxruntime as ort
        except ImportError as exc:
            raise RuntimeError(
                "onnxruntime is required for ONNX comparison."
            ) from exc
        self.session = ort.InferenceSession(onnx_model_path)
        inputs = self.session.get_inputs()
        if not inputs:
            raise ValueError("ONNX model must define at least one input.")
        self.input_name = inputs[0].name
        logger.info("Loaded ONNX model for comparison from %s", onnx_model_path)
        # Log the head->output mapping only once, on the first compare().
        self.head_mapping_logged = False
        # Worst (largest abs diff) record seen so far, per head.
        self.worst_records: Dict[str, Optional[DiffRecord]] = {
            head: None for head in HEADS
        }
        # Last sample's aligned ONNX outputs, keyed for dump_tensors().
        self.onnx_outputs: Dict[str, jax.Array] = {}

    def compare(
        self,
        jax_outputs: Dict[str, jax.Array],
        onnx_inputs_np: np.ndarray,
        sample_index: int,
    ) -> None:
        """Runs comparison for a single sample."""
        raw_onnx_outputs = self.session.run(
            None, {self.input_name: onnx_inputs_np}
        )
        if len(raw_onnx_outputs) != 3:
            raise ValueError(
                "Expected three outputs (wdl, policy, movesleft) from ONNX model."
            )
        jax_outputs_np = {k: np.asarray(v) for k, v in jax_outputs.items()}
        aligned_onnx, head_indices = self._align_onnx_outputs(
            jax_outputs_np, raw_onnx_outputs
        )
        if not self.head_mapping_logged:
            order = ", ".join(f"{h}=output[{head_indices[h]}]" for h in HEADS)
            logger.info("Aligned ONNX outputs to heads as: %s", order)
            self.head_mapping_logged = True
        self.onnx_outputs = {
            "onnx_value_pred": jnp.asarray(aligned_onnx["wdl"]),
            "onnx_policy_pred": jnp.asarray(aligned_onnx["policy"]),
            "onnx_movesleft_pred": jnp.asarray(aligned_onnx["movesleft"]),
        }
        for head in HEADS:
            record = self._log_diff_stats(
                head, jax_outputs_np[head], aligned_onnx[head], sample_index
            )
            current = self.worst_records[head]
            if current is None or record.diff > current.diff:
                self.worst_records[head] = record

    def log_summary(self) -> None:
        """Logs the worst difference found for each head."""
        for head in HEADS:
            record = self.worst_records[head]
            if record:
                logger.info(
                    "Worst ONNX abs diff for %s head: "
                    "batch=%d sample=%d index=%s diff=%0.6g "
                    "jax=%0.6g onnx=%0.6g",
                    head,
                    record.batch,
                    record.sample,
                    record.index,
                    record.diff,
                    record.jax_value,
                    record.onnx_value,
                )

    def _log_diff_stats(
        self,
        head: str,
        jax_output: np.ndarray,
        onnx_output: np.ndarray,
        sample_index: int,
    ) -> DiffRecord:
        # Logs binned diff statistics and returns the location/value of the
        # largest absolute difference for this head.
        if jax_output.shape != onnx_output.shape:
            raise ValueError(
                f"Shape mismatch for {head} head: "
                f"JAX {jax_output.shape} vs ONNX {onnx_output.shape}."
            )
        abs_diff, abs_stats, rel_stats = _collect_diff_statistics(
            jax_output, onnx_output
        )
        logger.info(
            "Batch %d %s head ONNX abs diff stats:\n%s",
            sample_index,
            head,
            _format_stats(abs_stats),
        )
        if rel_stats:
            logger.info(
                "Batch %d %s head ONNX rel diff stats:\n%s",
                sample_index,
                head,
                _format_stats(rel_stats),
            )
        else:
            logger.info(
                "Batch %d %s head ONNX rel diff stats: skipped (all |jax| < %.1e)",
                sample_index,
                head,
                RELATIVE_DIFFERENCE_EPSILON,
            )
        max_loc = np.unravel_index(int(np.argmax(abs_diff)), abs_diff.shape)
        return DiffRecord(
            batch=sample_index,
            # First axis is assumed to be the batch dimension — TODO confirm.
            sample=int(max_loc[0]) if max_loc else 0,
            index=tuple(int(i) for i in max_loc),
            diff=float(abs_diff[max_loc]),
            jax_value=float(jax_output[max_loc]),
            onnx_value=float(onnx_output[max_loc]),
        )

    def _align_onnx_outputs(
        self,
        jax_outputs: Dict[str, np.ndarray],
        onnx_outputs: Sequence[np.ndarray],
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, int]]:
        """Matches ONNX outputs to heads regardless of ordering differences."""
        remaining = [(i, np.asarray(o)) for i, o in enumerate(onnx_outputs)]
        aligned: Dict[str, np.ndarray] = {}
        indices: Dict[str, int] = {}

        def pop_match(
            predicate: Callable[[np.ndarray], bool],
        ) -> Optional[Tuple[int, np.ndarray]]:
            # Removes and returns the first remaining output that matches.
            for i, candidate in enumerate(remaining):
                if predicate(candidate[1]):
                    return remaining.pop(i)
            return None

        # First pass: match by exact shape.
        for head in HEADS:
            shape = jax_outputs[head].shape
            match = pop_match(lambda arr: arr.shape == shape)
            if match:
                idx, array = match
                aligned[head], indices[head] = array, idx
        # Second pass: fall back to matching by element count.
        for head in HEADS:
            if head in aligned:
                continue
            size = jax_outputs[head].size
            match = pop_match(lambda arr: arr.size == size)
            if match:
                idx, array = match
                aligned[head], indices[head] = array, idx
        if len(aligned) != len(HEADS):
            rem_shapes = [arr.shape for _, arr in remaining]
            raise ValueError(
                "Could not align ONNX outputs with JAX outputs. "
                f"Aligned: {list(aligned.keys())}; Unmatched shapes: {rem_shapes}"
            )
        for head in HEADS:
            aligned[head] = self._reshape_output(
                aligned[head], jax_outputs[head].shape, head
            )
        return aligned, indices

    def _reshape_output(
        self, array: np.ndarray, target_shape: Tuple[int, ...], head: str
    ) -> np.ndarray:
        # Size-matched outputs may still need reshaping to the JAX shape.
        if array.shape == target_shape:
            return array
        if array.size != int(np.prod(target_shape)):
            raise ValueError(
                f"Cannot reshape ONNX output for {head}: source shape "
                f"{array.shape}, target shape {target_shape}."
            )
        try:
            return np.reshape(array, target_shape)
        except ValueError as exc:
            raise ValueError(
                f"Failed to reshape ONNX output for {head} to {target_shape}."
            ) from exc


class Evaluation:
    """Orchestrates the model evaluation process."""

    def __init__(self, loss_fn: LczeroLoss):
        self.loss_fn = loss_fn

    def run(
        self,
        model: LczeroModel,
        datagen: Generator[tuple[np.ndarray, ...], None, None],
        num_samples: int,
        dumper: Dumper,
        onnx_comparator: Optional[OnnxComparator],
        softmax_jax_wdl: bool,
    ) -> None:
        """Processes num_samples batches: dump inputs/outputs/losses and
        optionally compare against the ONNX model."""
        # vmap over the batch axis; model/self are broadcast (in_axes=None).
        loss_vfn = jax.vmap(self._loss_for_grad, in_axes=(None, 0), out_axes=0)
        model_output_vfn = jax.vmap(
            self._model_for_output, in_axes=(None, 0), out_axes=0
        )
        for i in range(num_samples):
            logger.info("Processing sample %d/%d", i, num_samples)
            self._process_sample(
                model,
                datagen,
                i,
                dumper,
                onnx_comparator,
                loss_vfn,
                model_output_vfn,
                softmax_jax_wdl,
            )
            logger.info("Sample %d complete", i)

    def _process_sample(
        self,
        model: LczeroModel,
        datagen: Generator[tuple[np.ndarray, ...], None, None],
        sample_idx: int,
        dumper: Dumper,
        onnx_comparator: Optional[OnnxComparator],
        loss_vfn: Callable[
            [LczeroModel, TrainingSample],
            Tuple[jax.Array, Dict[str, jax.Array]],
        ],
        model_output_vfn: Callable[
            [LczeroModel, jax.Array],
            ModelPrediction,
        ],
        softmax_jax_wdl: bool,
    ) -> None:
        batch_tuple = next(datagen)
        logger.info("Fetched batch from dataloader")
        # DataLoader now returns tuple: (inputs, probabilities, values)
        batch = {
            "inputs": cast(jax.Array, jnp.asarray(batch_tuple[0])),
            "probabilities": cast(jax.Array, jnp.asarray(batch_tuple[1])),
            "values": cast(jax.Array, jnp.asarray(batch_tuple[2])),
        }
        dumper.dump_tensors(batch, "INPUT")
        predictions = model_output_vfn(model, cast(jax.Array, batch["inputs"]))
        value_preds = predictions.value
        policy_preds = predictions.policy
        movesleft_preds = predictions.movesleft
        # Flatten all head outputs for dumping
        outputs = {}
        for name, pred_tuple in value_preds.items():
            pred = pred_tuple[0]
            if softmax_jax_wdl:
                # Convert raw WDL logits to probabilities for readability
                # and ONNX comparability.
                pred = jax.nn.softmax(pred, axis=-1)
            outputs[f"value_pred/{name}"] = pred
        for name, pred in policy_preds.items():
            outputs[f"policy_pred/{name}"] = pred
        for name, pred in movesleft_preds.items():
            outputs[f"movesleft_pred/{name}"] = pred
        if onnx_comparator:
            # Compare only legacy heads
            jax_outputs_for_onnx = {
                "wdl": value_preds["winner"][0],
                "policy": policy_preds["vanilla"],
                "movesleft": movesleft_preds["main"],
            }
            onnx_inputs_np = np.asarray(batch["inputs"]).copy()
            # NOTE(review): presumably rescales input plane 109 (rule-50
            # counter?) to the lc0 ONNX convention — TODO confirm against
            # the input-encoding spec.
            onnx_inputs_np[:, 109, ...] *= 99
            onnx_comparator.compare(
                jax_outputs_for_onnx, onnx_inputs_np, sample_idx
            )
            outputs.update(onnx_comparator.onnx_outputs)
        dumper.dump_tensors(outputs, "OUTPUT")
        # Convert batch dict to TrainingSample for loss function
        batch_sample = TrainingSample(
            inputs=batch["inputs"],
            probabilities=batch["probabilities"],
            values=batch["values"],
        )
        per_sample_loss, unweighted_losses = loss_vfn(model, batch_sample)
        losses = {
            "per_sample_data_loss": per_sample_loss,
            "unweighted_losses": unweighted_losses,
        }
        dumper.dump_tensors(losses, "LOSSES")
        dumper.dump_structured(batch, outputs, losses)

    def _loss_for_grad(
        self, model_arg: LczeroModel, sample_arg: TrainingSample
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        # Thin wrapper so jax.vmap gets a plain function of (model, sample).
        return self.loss_fn(model_arg, sample_arg)

    @staticmethod
    def _model_for_output(
        model_arg: LczeroModel, inputs_arg: jax.Array
    ) -> ModelPrediction:
        return model_arg(inputs_arg)


def from_dataloader(
    loader: DataLoader,
) -> Generator[tuple[np.ndarray, ...], None, None]:
    """Infinitely yields batches from a DataLoader."""
    while True:
        yield loader.get_next()


def _load_model_from_checkpoint(config: RootConfig) -> LczeroModel:
    """Loads a model from the latest checkpoint."""
    if not config.training.checkpoint.path:
        logger.error("Checkpoint path must be set in the configuration.")
        sys.exit(1)
    mgr = ocp.CheckpointManager(
        config.training.checkpoint.path,
        options=ocp.CheckpointManagerOptions(create=True),
    )
    state = TrainingState.new_from_config(config.model, config.training)
    restored_state = mgr.restore(
        mgr.latest_step(), args=ocp.args.PyTreeRestore(state)
    )
    logger.info("Restored checkpoint from %s", config.training.checkpoint.path)
    assert isinstance(restored_state, TrainingState)
    # Rebuild the model graph with a fixed RNG seed, then splice in the
    # restored weights.
    model_graph, _ = nnx.split(
        LczeroModel(config.model, rngs=nnx.Rngs(params=42))
    )
    return nnx.merge(model_graph, restored_state.jit_state.model_state)


def _get_dataloader_config(
    config: RootConfig, batch_size_override: Optional[int]
) -> data_loader_config_pb2.DataLoaderConfig:
    """Gets the dataloader config, overriding batch size if specified.

    Note: mutates the tensor_generator stage of `config` in place when an
    override is given.
    """
    dl_config = config.data_loader
    if batch_size_override is None:
        return dl_config
    for stage in dl_config.stage:
        if stage.HasField("tensor_generator"):
            stage.tensor_generator.batch_size = batch_size_override
            logger.info("Overriding batch size to %d", batch_size_override)
            return dl_config
    raise ValueError(
        "tensor_generator stage is required to override batch size"
    )


def eval(
    config_filename: str,
    num_samples: Optional[int] = None,
    batch_size_override: Optional[int] = None,
    dump_to_stdout: bool = False,
    dump_to_file: Optional[str] = None,
    dump_to_shelve: Optional[str] = None,
    dump_to_json: Optional[str] = None,
    onnx_model: Optional[str] = None,
    softmax_jax_wdl: bool = True,
) -> None:
    """Main function to run the evaluation."""
    # Print full tensors rather than numpy's elided summaries.
    jnp.set_printoptions(threshold=sys.maxsize, suppress=False)
    config = RootConfig()
    logger.info("Reading configuration from: %s", config_filename)
    with open(config_filename, "r") as f:
        text_format.Parse(f.read(), config)
    model = _load_model_from_checkpoint(config)
    dl_config = _get_dataloader_config(config, batch_size_override)
    evaluation = Evaluation(loss_fn=LczeroLoss(config=config.training.losses))
    dumper = Dumper(dump_to_stdout, dump_to_file, dump_to_shelve, dump_to_json)
    onnx_comparator = OnnxComparator(onnx_model) if onnx_model else None
    samples_to_process = num_samples if num_samples is not None else 10
    logger.info("Starting evaluation with %d samples", samples_to_process)
    try:
        evaluation.run(
            model=model,
            datagen=from_dataloader(make_dataloader(dl_config)),
            num_samples=samples_to_process,
            dumper=dumper,
            onnx_comparator=onnx_comparator,
            softmax_jax_wdl=softmax_jax_wdl,
        )
    finally:
        dumper.close()
        if onnx_comparator:
            onnx_comparator.log_summary()
    logger.info("Evaluation complete")



================================================
FILE: src/lczero_training/training/init.py
================================================
import gzip
import logging
import os
import sys
from typing import Optional
import orbax.checkpoint as ocp from flax import nnx from google.protobuf import text_format from lczero_training.convert.leela_to_jax import ( LeelaImportOptions, fix_older_weights_file, leela_to_jax, ) from lczero_training.convert.leela_to_modelconfig import leela_to_modelconfig from lczero_training.training.state import TrainingState from proto import hlo_pb2, net_pb2 from proto.model_config_pb2 import ModelConfig from proto.root_config_pb2 import RootConfig logger = logging.getLogger(__name__) def _load_lc0_model_state( path: str, expected_config: ModelConfig, compute_dtype: hlo_pb2.XlaShapeProto.Type, ignore_config_mismatch: bool = False, ) -> tuple[nnx.State, int]: """Load lc0 weights, validate config, return (model_state, training_steps).""" lc0_weights = net_pb2.Net() with gzip.open(path, "rb") as f: lc0_weights.ParseFromString(f.read()) fix_older_weights_file(lc0_weights) leela_config = leela_to_modelconfig( lc0_weights, hlo_pb2.XlaShapeProto.F32, compute_dtype ) if leela_config != expected_config: if ignore_config_mismatch: logger.warning( "The provided lczero model configuration " "differs from the one in the config file (ignored)." ) else: logger.error( "The provided lczero model configuration " "differs from the one in the config file." ) logger.error(f"Config file model config: {expected_config}") logger.error(f"Leela model config: {leela_config}") sys.exit(1) import_options = LeelaImportOptions( weights_dtype=hlo_pb2.XlaShapeProto.F32, compute_dtype=compute_dtype ) model_state = leela_to_jax(lc0_weights, import_options) return model_state, lc0_weights.training_params.training_steps def init( config_filename: str, lczero_model: Optional[str], seed: int = 42, dry_run: bool = False, swa_initial_nets: int = 0, override_training_steps: Optional[int] = None, overwrite: bool = False, no_copy_swa: bool = False, ignore_config_mismatch: bool = False, ) -> None: """ Initializes a new training run. 
""" config = RootConfig() logger.info("Reading configuration from proto file") with open(config_filename, "r") as f: text_format.Parse(f.read(), config) checkpoint_path = config.training.checkpoint.path checkpoint_exists = checkpoint_path and os.path.exists(checkpoint_path) if not dry_run and checkpoint_exists and not overwrite: logger.error(f"Checkpoint path {checkpoint_path} already exists.") sys.exit(1) logger.info("Creating initial training state from configuration") training_state = TrainingState.new_from_config( model_config=config.model, training_config=config.training, ) if checkpoint_exists: logger.info(f"Loading from existing checkpoint: {checkpoint_path}") source_mgr = ocp.CheckpointManager( checkpoint_path, options=ocp.CheckpointManagerOptions(create=False), ) training_state = source_mgr.restore( source_mgr.latest_step(), args=ocp.args.PyTreeRestore(training_state), ) swa_enabled = config.training.HasField("swa") if lczero_model is None: if override_training_steps is not None: training_state = training_state.with_updated_step( override_training_steps ) if swa_enabled and swa_initial_nets > 0: training_state = training_state.replace( jit_state=training_state.jit_state.replace( num_averages=float(swa_initial_nets), ) ) else: logger.info(f"Loading lczero model: {lczero_model}") model_state, lc0_steps = _load_lc0_model_state( lczero_model, config.model, config.model.defaults.compute_dtype, ignore_config_mismatch, ) step = override_training_steps or lc0_steps new_swa_state = ( training_state.jit_state.swa_state if no_copy_swa else (model_state if swa_enabled else None) ) training_state = training_state.replace( jit_state=training_state.jit_state.replace( model_state=model_state, swa_state=new_swa_state, num_averages=float(swa_initial_nets) if swa_enabled else 0.0, ) ).with_updated_step(step) if dry_run: logger.info( f"Would save checkpoint to {config.training.checkpoint.path} " f"at step {training_state.jit_state.step}" ) else: checkpoint_mgr = 
ocp.CheckpointManager( config.training.checkpoint.path, options=ocp.CheckpointManagerOptions(create=True), ) step = training_state.jit_state.step if step in checkpoint_mgr.all_steps(): logger.info(f"Deleting existing checkpoint at step {step}") checkpoint_mgr.delete(step) checkpoint_mgr.wait_until_finished() logger.info( f"Saving checkpoint to {config.training.checkpoint.path} at step {step}" ) checkpoint_mgr.save(step=step, args=ocp.args.PyTreeSave(training_state)) checkpoint_mgr.wait_until_finished() logger.info("Initialization complete.") ================================================ FILE: src/lczero_training/training/lr_schedule.py ================================================ from typing import Callable, Sequence import jax.numpy as jnp import optax from proto.training_config_pb2 import LrSchedule def _create_rule_fn(rule: LrSchedule) -> Callable: """ Creates a JAX-compatible function for a single LR schedule rule. All data from the protobuf is extracted here and captured by the closure of the returned function. This avoids protobuf parsing inside the main schedule function which will be JIT-compiled. """ start_step = float(rule.starting_step) durations = list(rule.duration_steps) lrs = list(rule.lr) is_looping = rule.loop # Handle simple cases where the LR is constant for this rule. if not durations or not lrs: lr_val = jnp.asarray(lrs[-1] if lrs else 0.0, dtype=jnp.float32) # Return a simple lambda that ignores the step and returns the constant value. return lambda step: lr_val period = sum(durations) if period == 0.0: lr_val = jnp.asarray(lrs[-1], dtype=jnp.float32) return lambda step: lr_val # Pre-calculate JAX arrays for use in the schedule function. transitions = [ ( rule.transition[i] if i < len(rule.transition) else LrSchedule.Transition.CONSTANT ) for i in range(len(durations)) ] durs_j = jnp.asarray(durations, dtype=jnp.float32) ends_j = jnp.cumsum(durs_j) starts_j = ends_j - durs_j # Interpolation start/end LRs for each segment. 
    a_vals = [
        lrs[i] if i < len(lrs) else lrs[-1] for i in range(len(durations))
    ]
    b_vals = [
        lrs[i + 1] if (i + 1) < len(lrs) else a_vals[i]
        for i in range(len(durations))
    ]
    lrs_a_j = jnp.asarray(a_vals, dtype=jnp.float32)
    lrs_b_j = jnp.asarray(b_vals, dtype=jnp.float32)
    trans_j = jnp.asarray(transitions, dtype=jnp.int32)
    last_lr_j = jnp.asarray(lrs[-1], dtype=jnp.float32)
    period_j = jnp.asarray(period, dtype=jnp.float32)

    def rule_fn(step: jnp.ndarray) -> jnp.ndarray:
        """JAX-compatible function evaluating the LR for a given step."""
        rel_step = step - start_step
        if is_looping:
            rel_step = jnp.mod(rel_step, period_j)
        # Find active segment and calculate interpolation factor `t`.
        is_in_segment = (
            (rel_step >= starts_j) & (rel_step < ends_j) & (durs_j > 0)
        )
        # Use maximum() to avoid division by zero for zero-duration segments.
        t = jnp.clip((rel_step - starts_j) / jnp.maximum(durs_j, 1.0), 0.0, 1.0)
        # Calculate interpolated values for all segments for all transition types.
        lin = lrs_a_j + (lrs_b_j - lrs_a_j) * t
        cos = lrs_a_j + 0.5 * (1.0 - jnp.cos(jnp.pi * t)) * (lrs_b_j - lrs_a_j)
        # Select interpolation type for each segment. Default to CONSTANT (lrs_a_j).
        interp_vals = jnp.where(
            trans_j == LrSchedule.Transition.LINEAR, lin, lrs_a_j
        )
        interp_vals = jnp.where(
            trans_j == LrSchedule.Transition.COSINE, cos, interp_vals
        )
        # Select the value from the active segment by masking.
        # (At most one mask entry is True, so the sum picks that segment.)
        lr = jnp.sum(interp_vals * is_in_segment)
        # If not in any segment (e.g., gap between segments), use the last LR value.
        lr = jnp.where(jnp.any(is_in_segment), lr, last_lr_j)
        if not is_looping:
            # For non-looping rules, if past the end, clamp to the last LR.
            lr = jnp.where(rel_step >= period_j, last_lr_j, lr)
        return lr

    return rule_fn


def make_lr_schedule(schedules: Sequence[LrSchedule]) -> optax.Schedule:
    """
    Creates a learning rate schedule from a sequence of LrSchedule protobufs.

    The schedule is composed of multiple rules, each active for a certain
    range of training steps.
    """
    if not schedules:
        return lambda count: jnp.asarray(0.0, dtype=jnp.float32)
    rule_fns = [_create_rule_fn(rule) for rule in schedules]
    start_steps = jnp.asarray(
        [rule.starting_step for rule in schedules], dtype=jnp.float32
    )
    # Determine the learning rate to use for steps before the first rule begins.
    first_lrs = jnp.asarray(
        [rule.lr[0] if rule.lr else 0.0 for rule in schedules],
        dtype=jnp.float32,
    )
    earliest_rule_idx = jnp.argmin(start_steps)
    pre_start_lr = first_lrs[earliest_rule_idx]
    min_start_step = start_steps[earliest_rule_idx]

    def schedule(count: jnp.ndarray) -> jnp.ndarray:
        """The actual schedule function passed to Optax."""
        step = jnp.asarray(count, dtype=jnp.float32)
        # Find the index of the active rule. The active rule is the one with the
        # largest starting_step that is less than or equal to the current step.
        eligible_mask = step >= start_steps
        # Replace non-eligible start_steps with a large negative number so they
        # are ignored by argmax.
        effective_starts = jnp.where(eligible_mask, start_steps, -1.0)
        active_rule_idx = jnp.argmax(effective_starts)
        # All rule fns are evaluated (JIT-friendly); only one result is used.
        all_lrs = jnp.stack([fn(step) for fn in rule_fns])
        lr = all_lrs[active_rule_idx]
        # If the current step is before any rule starts, use the pre-start LR.
        return jnp.where(step < min_start_step, pre_start_lr, lr)

    return schedule



================================================
FILE: src/lczero_training/training/migrate_checkpoint.py
================================================
from typing import Any, Dict, Iterable, List, Set, Tuple

import jax
import numpy as np
import orbax.checkpoint as ocp
from flax import serialization
from google.protobuf import text_format
from orbax.checkpoint.utils import tuple_path_from_keypath

from lczero_training.training import state as state_lib
from proto import checkpoint_migration_config_pb2, root_config_pb2


def _str_to_key_path(path_str: str) -> tuple[str, ...]:
    # "a.b.c" -> ("a", "b", "c")
    return tuple(path_str.split("."))


def _load_new_state(
    root_config: root_config_pb2.RootConfig, serialized_model: bool
) -> Any:
    # Builds the migration target state; optionally as a plain state dict.
    new_state = state_lib.TrainingState.new_from_config(
        root_config.model, root_config.training
    )
    if serialized_model:
        return serialization.to_state_dict(new_state)
    return new_state


def load_checkpoint(
    checkpoint_path: str, checkpoint_step: int | None = None
) -> Tuple[Any, int]:
    """Load a checkpoint from the given path.

    Args:
        checkpoint_path: Path to the checkpoint directory.
        checkpoint_step: Step to load, or None to load the latest.

    Returns:
        Tuple of (checkpoint_state, checkpoint_step).
    """
    manager = ocp.CheckpointManager(
        checkpoint_path,
        options=ocp.CheckpointManagerOptions(create=False),
    )
    if checkpoint_step is None:
        checkpoint_step = manager.latest_step()
    if checkpoint_step is None:
        raise ValueError(f"No checkpoints found in {checkpoint_path}")
    return manager.restore(checkpoint_step), checkpoint_step


def get_checkpoint_steps(
    checkpoint_path: str,
    min_step: int | None = None,
    max_step: int | None = None,
) -> list[int]:
    """Get all checkpoint steps in the given range.

    Args:
        checkpoint_path: Path to the checkpoint directory.
        min_step: Minimum step (inclusive), or None for no minimum.
        max_step: Maximum step (inclusive), or None for no maximum.
    Returns:
        List of checkpoint steps in ascending order.
    """
    # Open read-only (create=False) so listing never creates directories.
    manager = ocp.CheckpointManager(
        checkpoint_path,
        options=ocp.CheckpointManagerOptions(create=False),
    )
    all_steps = sorted(manager.all_steps())
    filtered_steps = []
    for step in all_steps:
        if min_step is not None and step < min_step:
            continue
        if max_step is not None and step > max_step:
            continue
        filtered_steps.append(step)
    # NOTE(review): all_steps is already sorted, so this sort is redundant
    # (kept for safety; filtering preserves order).
    filtered_steps.sort()
    return filtered_steps


def _load_old_state(
    checkpoint_path: str, checkpoint_step: int | None
) -> Tuple[Any, int]:
    """Thin wrapper over load_checkpoint; returns (state, resolved step)."""
    return load_checkpoint(checkpoint_path, checkpoint_step)


def load_migration_rules(rules_file: str | None) -> List[Tuple[Any, Any]]:
    """Load migration rules from a CheckpointMigrationConfig file.

    Args:
        rules_file: Path to the CheckpointMigrationConfig textproto file, or
            None to return empty rules.

    Returns:
        List of (from_path, to_path) tuples representing migration rules.
        Either element may be None when the corresponding proto field is
        empty (ignore-rule / keep-rule, see Migration.apply_rules).
    """
    rules = []
    if rules_file:
        migration_config = (
            checkpoint_migration_config_pb2.CheckpointMigrationConfig()
        )
        with open(rules_file, "r") as f:
            text_format.Parse(f.read(), migration_config)
        for rule_proto in migration_config.rule:
            from_path = (
                _str_to_key_path(rule_proto.from_path)
                if rule_proto.from_path
                else None
            )
            to_path = (
                _str_to_key_path(rule_proto.to_path)
                if rule_proto.to_path
                else None
            )
            rules.append((from_path, to_path))
    return rules


def _format_value(value: Any) -> str:
    """Render a tree leaf for diff output: dtype+shape for arrays, repr otherwise."""
    if isinstance(value, (np.ndarray, jax.Array)):
        return f"{value.dtype}{value.shape}"
    return repr(value)


def _format_path_diff(
    unhandled_source: Set[Tuple[str, ...]],
    unhandled_dest: Set[Tuple[str, ...]],
    old_paths: Dict[Tuple[str, ...], Any],
    new_paths: Dict[Tuple[str, ...], Any],
) -> str:
    """Format unmapped paths as a unified-diff-like listing (- old, + new)."""
    diff = []
    for p in sorted(list(unhandled_source | unhandled_dest)):
        p_str = ".".join(p)
        if p in unhandled_source:
            diff.append(f"- {p_str}: {_format_value(old_paths[p])}")
        if p in unhandled_dest:
            diff.append(f"+ {p_str}: {_format_value(new_paths[p])}")
    return "\n".join(diff)


class Migration:
    """Copies leaves from an old checkpoint pytree into a new state pytree.

    The old and new states are flattened to path-keyed leaves; rules move,
    ignore, or keep path prefixes, and any path present in both trees is
    copied verbatim. All problems are accumulated in `self.errors` and
    raised together at the end of `run`.
    """

    def __init__(self, old_state: Any, new_state: Any):
        old_leaves = jax.tree_util.tree_leaves_with_path(old_state)
        new_flat, self.new_treedef = jax.tree_util.tree_flatten_with_path(
            new_state
        )
        # Path (tuple of key strings) -> leaf value, for both trees.
        self.old_paths: Dict[Tuple[str, ...], Any] = {
            tuple_path_from_keypath(path): value for path, value in old_leaves
        }
        self.new_paths: Dict[Tuple[str, ...], Any] = {
            tuple_path_from_keypath(path): value for path, value in new_flat
        }
        # Flat leaf list of the new tree; mutated in place as rules apply,
        # then unflattened back into a pytree by `run`.
        self.new_leaves: List[Any] = [value for _, value in new_flat]
        self.new_path_to_idx: Dict[Tuple[str, ...], int] = {
            tuple_path_from_keypath(path): i
            for i, (path, _) in enumerate(new_flat)
        }
        # Working sets of not-yet-handled paths; rules remove entries.
        self.source_paths: Set[Tuple[str, ...]] = set(self.old_paths.keys())
        self.dest_paths: Set[Tuple[str, ...]] = set(self.new_path_to_idx.keys())
        print(f"{len(self.source_paths & self.dest_paths)} common keys")
        print(f"{len(self.source_paths - self.dest_paths)} keys disappeared")
        print(f"{len(self.dest_paths - self.source_paths)} keys appeared")
        self.errors: List[str] = []

    def _apply_move_rule(
        self, from_path: Tuple[str, ...], to_path: Tuple[str, ...]
    ) -> None:
        """Move every old leaf under from_path to the same suffix under to_path."""
        if from_path == to_path:
            self.errors.append(
                f"from_path and to_path are the same: {from_path}"
            )
            return
        source_prefixed = {
            p for p in self.source_paths if p[: len(from_path)] == from_path
        }
        dest_prefixed = {
            p for p in self.dest_paths if p[: len(to_path)] == to_path
        }
        if not source_prefixed:
            self.errors.append(f"from_path {from_path} not found in old state")
        if not dest_prefixed:
            self.errors.append(f"to_path {to_path} not found in new state")
        for p in source_prefixed:
            new_p = to_path + p[len(from_path) :]
            if new_p in self.dest_paths:
                idx = self.new_path_to_idx[new_p]
                self.new_leaves[idx] = self.old_paths[p]
                self.source_paths.remove(p)
                self.dest_paths.remove(new_p)
            else:
                self.errors.append(f"Path {new_p} not found in new state")

    def _apply_ignore_rule(self, from_path: Tuple[str, ...]) -> None:
        """Drop old-state leaves under from_path (do not migrate them)."""
        source_prefixed = {
            p for p in self.source_paths if p[: len(from_path)] == from_path
        }
        if not source_prefixed:
            self.errors.append(f"from_path {from_path} not found in old state")
        self.source_paths -= source_prefixed

    def _apply_keep_rule(self, to_path: Tuple[str, ...]) -> None:
        """Keep new-state leaves under to_path as-is (freshly initialized)."""
        dest_prefixed = {
            p for p in self.dest_paths if p[: len(to_path)] == to_path
        }
        if not dest_prefixed:
            self.errors.append(f"to_path {to_path} not found in new state")
        self.dest_paths -= dest_prefixed

    def apply_rules(self, rules: List[Tuple[Any, Any]]) -> None:
        """Dispatch each rule: both paths -> move, from-only -> ignore, to-only -> keep."""
        for from_path, to_path in rules:
            if from_path and to_path:
                self._apply_move_rule(from_path, to_path)
            elif from_path:
                self._apply_ignore_rule(from_path)
            elif to_path:
                self._apply_keep_rule(to_path)

    def run(self, rules: List[Tuple[Any, Any]]) -> Any:
        """Apply rules, copy identical paths, and return the migrated pytree.

        Raises:
            ValueError: with all accumulated errors if any rule failed or any
                path was left unmapped on either side.
        """
        self.apply_rules(rules)
        # Copy remaining paths
        copied_paths = self.source_paths & self.dest_paths
        for p in copied_paths:
            idx = self.new_path_to_idx[p]
            self.new_leaves[idx] = self.old_paths[p]
        self.source_paths -= copied_paths
        self.dest_paths -= copied_paths
        unhandled_source = self.source_paths
        unhandled_dest = self.dest_paths
        if unhandled_source or unhandled_dest:
            self.errors.append(
                "Unmapped paths:\n"
                + _format_path_diff(
                    unhandled_source,
                    unhandled_dest,
                    self.old_paths,
                    self.new_paths,
                )
            )
        if self.errors:
            raise ValueError("\n".join(self.errors))
        # getattr avoids a static-typing complaint on treedef.unflatten;
        # behaviorally identical to self.new_treedef.unflatten(...).
        return getattr(self.new_treedef, "unflatten")(self.new_leaves)


def _save_checkpoint(
    migrated_state: Any,
    new_checkpoint_path: str,
    new_checkpoint_step: int,
    overwrite: bool,
) -> None:
    """Save the migrated state at the given step, optionally replacing an existing one.

    Raises:
        ValueError: if the step already exists and overwrite is False.
    """
    manager = ocp.CheckpointManager(
        new_checkpoint_path,
        ocp.PyTreeCheckpointer(),
        options=ocp.CheckpointManagerOptions(
            create=True, save_interval_steps=1, todelete_subdir="trash"
        ),
    )
    if new_checkpoint_step in manager.all_steps():
        if overwrite:
            manager.delete(new_checkpoint_step)
            manager.wait_until_finished()
        else:
            raise ValueError(
                f"Checkpoint already exists at {new_checkpoint_step} in "
                f"{new_checkpoint_path}. "
                "Use --overwrite to overwrite."
            )
    manager.save(new_checkpoint_step, migrated_state)
    manager.wait_until_finished()
    print(
        f"New checkpoint saved successfully to {new_checkpoint_path} at step "
        f"{new_checkpoint_step}."
    )


def _dump_paths(paths: Iterable[Tuple[str, ...]], field: str) -> None:
    """Print each path as a textproto `rule { <field>: "a.b.c" }` line.

    Sorted with numeric components compared as integers (natural order).
    """
    def key(p: Tuple[str, ...]) -> tuple:
        return tuple(int(c) if c.isdigit() else c for c in p)

    for path in sorted(paths, key=key):
        print(f'rule {{ {field}: "{".".join(path)}" }}')


def migrate_checkpoint(
    config: str,
    new_checkpoint: str | None,
    overwrite: bool,
    rules_file: str | None,
    serialized_model: bool,
    checkpoint_step: int | None,
    new_checkpoint_step: int | None,
    dump_source_paths: bool = False,
    dump_destination_paths: bool = False,
) -> None:
    """Migrates a checkpoint to a new training state."""
    root_config = root_config_pb2.RootConfig()
    with open(config, "r") as f:
        text_format.Parse(f.read(), root_config)
    new_state = _load_new_state(root_config, serialized_model)
    old_state, old_checkpoint_step = _load_old_state(
        root_config.training.checkpoint.path, checkpoint_step
    )
    rules = load_migration_rules(rules_file)
    migration = Migration(old_state, new_state)
    # Dump modes only print paths (rule templates) and never run/save.
    if dump_source_paths:
        _dump_paths(migration.old_paths.keys(), "from_path")
    if dump_destination_paths:
        _dump_paths(migration.new_paths.keys(), "to_path")
    if dump_source_paths or dump_destination_paths:
        return
    migrated_state = migration.run(rules)
    if new_checkpoint:
        checkpoint_path = new_checkpoint
    elif overwrite:
        checkpoint_path = root_config.training.checkpoint.path
    else:
        # Dry run: no destination given and no overwrite requested.
        print("Migration check successful.")
        return
    if new_checkpoint_step is None:
        new_checkpoint_step = old_checkpoint_step
    _save_checkpoint(
        migrated_state, checkpoint_path, new_checkpoint_step, overwrite
    )


================================================
FILE: src/lczero_training/training/optimizer.py
================================================
from functools import partial

import jax
import jax.numpy as jnp
import optax
from flax import nnx

from lczero_training.training.utils import make_weights_mask
from proto.training_config_pb2 import OptimizerConfig

# Optax state namedtuples that are known to carry a `count` field.
_STATES_WITH_COUNT = (
    optax.ScaleByAdamState,
    optax.ScaleByScheduleState,
)


def update_optimizer_step(
    opt_state: optax.OptState, step: int
) -> optax.OptState:
    """Updates all step counters in the optimizer state tree."""
    step_array = jnp.array(step, dtype=jnp.int32)

    def has_count(x: object) -> bool:
        # Detect namedtuple-like states that expose a `count` field.
        return hasattr(x, "_fields") and "count" in x._fields

    def update_count(x: optax.OptState) -> optax.OptState:
        if not has_count(x):
            return x
        # Fail loudly on unknown stateful types rather than silently
        # rewriting a field we do not understand.
        assert isinstance(x, _STATES_WITH_COUNT), (
            f"Unexpected state type with 'count' field: {type(x).__name__}"
        )
        return x._replace(count=step_array)

    return jax.tree_util.tree_map(update_count, opt_state, is_leaf=has_count)


def make_gradient_transformation(
    config: OptimizerConfig,
    *,
    max_grad_norm: float | None = None,
    lr_schedule: optax.Schedule,
) -> optax.GradientTransformation:
    """Build the optax transformation described by the OptimizerConfig.

    Supports nadamw / nadam / sgd; optionally prepends global-norm gradient
    clipping and wraps with freeze masking when `freeze_selector` is set.

    Raises:
        ValueError: on an unsupported `optimizer_type` oneof value.
    """
    if config.HasField("nadamw"):
        nadamw = config.nadamw
        tx = optax.nadamw(
            lr_schedule,
            b1=nadamw.beta_1,
            b2=nadamw.beta_2,
            eps=nadamw.epsilon,
            weight_decay=nadamw.weight_decay,
            mask=partial(make_weights_mask, nadamw.decay_selector),
        )
    elif config.HasField("nadam"):
        nadam = config.nadam
        tx = optax.nadam(
            lr_schedule,
            b1=nadam.beta_1,
            b2=nadam.beta_2,
            eps=nadam.epsilon,
        )
    elif config.HasField("sgd"):
        sgd = config.sgd
        tx = optax.sgd(
            lr_schedule,
            # Proto default 0 means "no momentum" -> pass None to optax.
            momentum=sgd.momentum if sgd.momentum else None,
            nesterov=sgd.nesterov,
        )
    else:
        raise ValueError(
            "Unsupported optimizer type: {}".format(
                config.WhichOneof("optimizer_type")
            )
        )
    if max_grad_norm is not None and max_grad_norm > 0:
        tx = optax.chain(optax.clip_by_global_norm(max_grad_norm), tx)
    if config.HasField("freeze_selector"):
        freeze_mask = partial(make_weights_mask, config.freeze_selector)

        def trainable_mask(p: nnx.State) -> nnx.State:
            # Complement of the freeze mask.
            return jax.tree.map(lambda x: not x, freeze_mask(p))

        # Apply the optimizer only to trainable params; zero out updates
        # for frozen ones.
        tx = optax.chain(
            optax.masked(tx, trainable_mask),
            optax.masked(optax.set_to_zero(), freeze_mask),
        )
    return tx


================================================
FILE: src/lczero_training/training/overfit.py
================================================
"""Overfitting utility for quickly validating training setup."""

import csv
import logging
from contextlib import suppress
from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from google.protobuf import text_format
from jax import tree_util

from lczero_training.dataloader import DataLoader, make_dataloader
from lczero_training.model.loss_function import LczeroLoss
from lczero_training.model.model import LczeroModel
from lczero_training.training.lr_schedule import make_lr_schedule
from lczero_training.training.optimizer import make_gradient_transformation
from lczero_training.training.state import (
    TrainingBatch,
    TrainingSample,
    TrainingState,
)
from lczero_training.training.training import Training
from proto.root_config_pb2 import RootConfig

logger = logging.getLogger(__name__)


def _stop_loader(loader: DataLoader) -> None:
    """Best-effort stop of the data loader; failures are deliberately ignored."""
    with suppress(Exception):
        loader.stop()


def _prepare_batch(batch_tuple: tuple) -> TrainingBatch:
    # DataLoader now returns tuple: (inputs, probabilities, values)
    return TrainingBatch(
        inputs=jnp.asarray(batch_tuple[0]),
        probabilities=jnp.asarray(batch_tuple[1]),
        values=jnp.asarray(batch_tuple[2]),
    )


def _make_eval_step(graphdef: nnx.GraphDef, loss_fn: LczeroLoss) -> Any:
    """Build a jitted eval step returning (mean loss, mean unweighted losses)."""

    @partial(nnx.jit, static_argnames=())
    def eval_step(
        model_state: nnx.State, batch: TrainingBatch
    ) -> tuple[jax.Array, Any]:
        model = nnx.merge(graphdef, model_state)

        def loss_for_batch(
            model_arg: LczeroModel, sample_arg: TrainingSample
        ) -> tuple[jax.Array, Any]:
            return loss_fn(model_arg, sample_arg)

        loss_vfn = jax.vmap(loss_for_batch, in_axes=(None, 0), out_axes=0)
        # vmap automatically distributes TrainingBatch over batch dimension,
        # calling loss_for_batch with TrainingSample (single samples).
        per_sample_loss, unweighted_losses = loss_vfn(model, batch)  # type: ignore[arg-type]
        mean_loss = jnp.mean(per_sample_loss)
        mean_unweighted = tree_util.tree_map(jnp.mean, unweighted_losses)
        return mean_loss, mean_unweighted

    return eval_step


def overfit(
    *,
    config_filename: str,
    num_steps: int,
    coin_flip: bool = False,
    csv_file: str | None = None,
) -> None:
    """Runs an overfitting loop to validate training.

    Args:
        config_filename: Path to a RootConfig textproto.
        num_steps: Number of training steps (per phase in coin-flip mode).
        coin_flip: When True, train on batch A while evaluating on batch B,
            then swap (cross-evaluation on two fixed batches).
        csv_file: Optional path; when set, per-step metrics are appended as CSV.

    Raises:
        ValueError: if num_steps is not positive or more than one JAX device
            is visible (this utility is single-GPU only).
    """
    if num_steps <= 0:
        raise ValueError("num_steps must be a positive integer")
    if jax.device_count() > 1:
        raise ValueError(
            f"Overfit utility does not support multi-GPU training. "
            f"Detected {jax.device_count()} devices. "
            f"Please set CUDA_VISIBLE_DEVICES to use only one GPU."
        )
    config = RootConfig()
    logger.info("Reading configuration from proto file")
    with open(config_filename, "r") as config_file:
        text_format.Parse(config_file.read(), config)
    logger.info("Creating data loader and fetching batches")
    loader = make_dataloader(config.data_loader)
    try:
        batch_a = loader.get_next()
        batch_b = loader.get_next() if coin_flip else None
    finally:
        # The loader is only needed to grab the fixed batch(es).
        _stop_loader(loader)
    prepared_batch_a = _prepare_batch(batch_a)
    prepared_batch_b = _prepare_batch(batch_b) if batch_b is not None else None
    logger.info("Creating training state from configuration")
    training_state = TrainingState.new_from_config(
        model_config=config.model,
        training_config=config.training,
    )
    graphdef, _ = nnx.split(
        LczeroModel(config=config.model, rngs=nnx.Rngs(params=42))
    )
    jit_state = training_state.jit_state
    lr_sched = make_lr_schedule(config.training.lr_schedule)
    optimizer_tx = make_gradient_transformation(
        config.training.optimizer,
        max_grad_norm=getattr(config.training, "max_grad_norm", 0.0),
        lr_schedule=lr_sched,
    )
    loss_fn = LczeroLoss(config=config.training.losses)
    training = Training(
        optimizer_tx=optimizer_tx,
        graphdef=graphdef,
        loss_fn=loss_fn,
    )
    eval_step = _make_eval_step(graphdef, loss_fn)
    csv_handle = None
    csv_writer: Any | None = None
    if csv_file is not None:
        logger.info("Writing overfit results to %s", csv_file)
        csv_handle = open(csv_file, "w", newline="")
        csv_writer = csv.writer(csv_handle)
        csv_writer.writerow(
            [
                "step",
                "train_batch",
                "train_loss",
                "train_unweighted",
                "eval_batch",
                "eval_loss",
                "eval_unweighted",
            ]
        )
        csv_handle.flush()

    def log_step(
        *,
        step_value: int,
        train_batch_name: str,
        train_loss: float,
        train_unweighted: Any,
        eval_batch_name: str | None,
        eval_loss: float | None,
        eval_unweighted: Any | None,
    ) -> None:
        # Log to the logger and, when enabled, mirror the row into the CSV.
        if eval_batch_name is None or eval_loss is None:
            logger.info(
                "Step %d: batch=%s train_loss=%f, unweighted_losses=%s",
                step_value,
                train_batch_name,
                train_loss,
                train_unweighted,
            )
        else:
            logger.info(
                (
                    "Step %d: trained %s train_loss=%f, unweighted_losses=%s; "
                    "evaluated %s eval_loss=%f, eval_unweighted=%s"
                ),
                step_value,
                train_batch_name,
                train_loss,
                train_unweighted,
                eval_batch_name,
                eval_loss,
                eval_unweighted,
            )
        if csv_writer is not None and csv_handle is not None:
            csv_writer.writerow(
                [
                    step_value,
                    train_batch_name,
                    train_loss,
                    repr(train_unweighted),
                    eval_batch_name or "",
                    "" if eval_loss is None else eval_loss,
                    "" if eval_unweighted is None else repr(eval_unweighted),
                ]
            )
            csv_handle.flush()

    try:
        if coin_flip:
            if prepared_batch_b is None:
                raise RuntimeError(
                    "Coin flip mode requires two batches but only one was fetched"
                )
            logger.info(
                "Starting coin-flip overfit: %d steps on batch A then %d on batch B",
                num_steps,
                num_steps,
            )

            def run_phase(
                train_batch: TrainingBatch,
                train_name: str,
                eval_batch: TrainingBatch,
                eval_name: str,
            ) -> None:
                # Trains on one fixed batch, evaluating on the other each step.
                nonlocal jit_state
                for _ in range(num_steps):
                    jit_state, metrics = training.train_step(
                        optimizer_tx,
                        jit_state,
                        train_batch,
                    )
                    loss = metrics["loss"]
                    unweighted_losses = metrics["unweighted_losses"]
                    loss_value, unweighted_host = jax.device_get(
                        (loss, unweighted_losses)
                    )
                    loss_value = float(np.asarray(loss_value))
                    unweighted_host = tree_util.tree_map(
                        lambda x: float(np.asarray(x)), unweighted_host
                    )
                    eval_loss, eval_unweighted = eval_step(
                        jit_state.model_state, eval_batch
                    )
                    eval_loss, eval_unweighted = jax.device_get(
                        (eval_loss, eval_unweighted)
                    )
                    eval_loss_value = float(np.asarray(eval_loss))
                    eval_unweighted_host = tree_util.tree_map(
                        lambda x: float(np.asarray(x)), eval_unweighted
                    )
                    step_value = int(
                        np.asarray(jax.device_get(jit_state.step)).flat[0]
                    )
                    log_step(
                        step_value=step_value,
                        train_batch_name=train_name,
                        train_loss=loss_value,
                        train_unweighted=unweighted_host,
                        eval_batch_name=eval_name,
                        eval_loss=eval_loss_value,
                        eval_unweighted=eval_unweighted_host,
                    )

            run_phase(prepared_batch_a, "A", prepared_batch_b, "B")
            run_phase(prepared_batch_b, "B", prepared_batch_a, "A")
        else:
            logger.info("Starting overfit loop for %d steps", num_steps)
            for _ in range(num_steps):
                jit_state, metrics = training.train_step(
                    optimizer_tx,
                    jit_state,
                    prepared_batch_a,
                )
                loss = metrics["loss"]
                unweighted_losses = metrics["unweighted_losses"]
                loss_value, unweighted_host = jax.device_get(
                    (loss, unweighted_losses)
                )
                loss_value = float(np.asarray(loss_value))
                unweighted_host = tree_util.tree_map(
                    lambda x: float(np.asarray(x)), unweighted_host
                )
                step_value = int(
                    np.asarray(jax.device_get(jit_state.step)).flat[0]
                )
                log_step(
                    step_value=step_value,
                    train_batch_name="single",
                    train_loss=loss_value,
                    train_unweighted=unweighted_host,
                    eval_batch_name=None,
                    eval_loss=None,
                    eval_unweighted=None,
                )
    finally:
        if csv_handle is not None:
            csv_handle.close()


================================================
FILE: src/lczero_training/training/state.py
================================================
import dataclasses
import logging
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp
import jax.sharding as jshard
import numpy as np
import optax
from flax import nnx
from flax.struct import dataclass

from lczero_training.model.model import LczeroModel
from lczero_training.training.lr_schedule import make_lr_schedule
from lczero_training.training.optimizer import (
    make_gradient_transformation,
    update_optimizer_step,
)
from proto.model_config_pb2 import ModelConfig
from proto.training_config_pb2 import TrainingConfig

logger = logging.getLogger(__name__)


@jax.tree_util.register_dataclass
@dataclasses.dataclass
class TrainingSample:
    """Single training sample without batch dimension.

    Used for vmap over individual samples in loss computation.

    Fields:
        inputs: Input planes tensor [112, 8, 8]
        probabilities: Policy probabilities tensor [1858]
        values: Combined values tensor [6, 3] where:
            - Index 0: result [result_q, result_d, plies_left]
            - Index 1: best [best_q, best_d, best_m]
            - Index 2: played [played_q, played_d, played_m]
            - Index 3: orig [orig_q, orig_d, orig_m] (may contain NaN)
            - Index 4: root [root_q, root_d, root_m]
            - Index 5: st [q_st, d_st, NaN]
    """

    inputs: jax.Array
    probabilities: jax.Array
    values: jax.Array


@jax.tree_util.register_dataclass
@dataclasses.dataclass
class TrainingBatch:
    """Batch of training data with inputs, probabilities, and values tensors.

    Fields:
        inputs: Input planes tensor [batch, 112, 8, 8]
        probabilities: Policy probabilities tensor [batch, 1858]
        values: Combined values tensor [batch, 6, 3] where:
            - Index 0: result [result_q, result_d, plies_left]
            - Index 1: best [best_q, best_d, best_m]
            - Index 2: played [played_q, played_d, played_m]
            - Index 3: orig [orig_q, orig_d, orig_m] (may contain NaN)
            - Index 4: root [root_q, root_d, root_m]
            - Index 5: st [q_st, d_st, NaN]
    """

    # Fields also accept NamedSharding so an instance can describe a
    # per-field sharding spec for jax.jit in_shardings (see training.py).
    inputs: Union[jax.Array, jshard.NamedSharding]
    probabilities: Union[jax.Array, jshard.NamedSharding]
    values: Union[jax.Array, jshard.NamedSharding]

    @classmethod
    def from_tuple(
        cls, tensor_tuple: tuple[np.ndarray, ...]
    ) -> "TrainingBatch":
        """Create TrainingBatch from tuple returned by DataLoader."""
        if len(tensor_tuple) != 3:
            raise ValueError(
                f"Expected tuple of 3 tensors, got {len(tensor_tuple)}"
            )
        return cls(
            inputs=jnp.asarray(tensor_tuple[0]),
            probabilities=jnp.asarray(tensor_tuple[1]),
            values=jnp.asarray(tensor_tuple[2]),
        )


@dataclass
class JitTrainingState:
    """Device-resident training state that flows through the jitted step."""

    # Global step counter (device scalar once jitted).
    step: int
    model_state: nnx.State
    opt_state: Optional[optax.OptState]
    # SWA state mirrors model_state structure when enabled; None otherwise.
    # Marked non-pytree to exclude from JIT/pjit inputs and device transfers.
    swa_state: Optional[nnx.State]
    # Effective number of model snapshots accumulated into SWA (can be fractional).
    num_averages: float

    def replace(self, **changes: Any) -> "JitTrainingState":
        """Returns a new instance of the class with the specified changes."""
        return dataclasses.replace(self, **changes)


@dataclass
class TrainingState:
    """Full training state: jitted state plus host-side bookkeeping."""

    jit_state: JitTrainingState
    # Number of attention heads in the encoder (from the model config).
    num_heads: int
    # Last chunk source that was available when the last epoch started training.
    last_chunk_source: str = ""

    def replace(self, **changes: Any) -> "TrainingState":
        """Returns a new instance of the class with the specified changes."""
        return dataclasses.replace(self, **changes)

    def with_updated_step(self, step: int) -> "TrainingState":
        """Returns a copy with updated step in both jit_state and optimizer."""
        updated_opt_state = (
            update_optimizer_step(self.jit_state.opt_state, step)
            if self.jit_state.opt_state is not None
            else None
        )
        return self.replace(
            jit_state=self.jit_state.replace(
                step=step,
                opt_state=updated_opt_state,
            )
        )

    @staticmethod
    def new_from_config(
        model_config: ModelConfig, training_config: TrainingConfig
    ) -> "TrainingState":
        """Build a fresh state (step 0) from model + training configs.

        Uses a fixed params seed (42) so initialization is reproducible.
        """
        rngs = nnx.Rngs(params=42)
        model_state = nnx.state(LczeroModel(config=model_config, rngs=rngs))
        lr_sched = make_lr_schedule(training_config.lr_schedule)
        opt_state = make_gradient_transformation(
            training_config.optimizer,
            max_grad_norm=getattr(training_config, "max_grad_norm", 0.0),
            lr_schedule=lr_sched,
        ).init(model_state)
        jit_state = JitTrainingState(
            step=0,
            model_state=model_state,
            opt_state=opt_state,
            # SWA starts as a copy of the freshly-initialized model.
            swa_state=model_state,
            num_averages=0.0,
        )
        return TrainingState(
            jit_state=jit_state,
            num_heads=model_config.encoder.heads,
        )


================================================
FILE: src/lczero_training/training/tensorboard.py
================================================
"""Utilities for writing training metrics to TensorBoard event files."""

from __future__ import annotations

import logging
from collections.abc import Mapping
from typing import Any, Dict

import jax
import numpy as np
from tensorboardX import SummaryWriter

logger = logging.getLogger(__name__)
MetricsDict = Dict[str, Any]


def _to_ndarray(value: Any) -> np.ndarray:
    """Convert a (possibly device-resident) value to a host numpy array."""
    try:
        return np.asarray(jax.device_get(value))
    except TypeError:
        # Not a JAX value; fall back to plain numpy conversion.
        return np.asarray(value)


def _to_scalar(value: Any) -> float | None:
    """Return the value as a float scalar, or None (with a warning) if non-scalar."""
    array = _to_ndarray(value)
    if array.ndim == 0 or array.size == 1:
        return float(array.reshape(()))
    logger.warning(
        "Skipping non-scalar metric with shape %s when logging to TensorBoard.",
        array.shape,
    )
    return None


def _flatten_metrics(
    metrics: Mapping[str, Any], prefix: str = ""
) -> Dict[str, float]:
    """Flatten a nested metrics mapping into '/'-joined scalar tags."""
    scalars: Dict[str, float] = {}
    for key, value in metrics.items():
        tag = f"{prefix}{key}" if prefix else key
        if isinstance(value, Mapping):
            # Recurse into nested dicts, extending the tag path.
            scalars.update(_flatten_metrics(value, f"{tag}/"))
            continue
        scalar = _to_scalar(value)
        if scalar is not None:
            scalars[tag] = scalar
    return scalars


def _to_step(step: Any) -> int:
    """Coerce a step value (possibly a device scalar) to a Python int."""
    return int(_to_ndarray(step).reshape(()))


class TensorboardLogger:
    """Writes scalar training metrics to TensorBoard."""

    def __init__(self, logdir: str) -> None:
        self._writer = SummaryWriter(logdir)

    def log(self, step: int, metrics: MetricsDict) -> None:
        """Write every scalar in (possibly nested) metrics at the given step."""
        global_step = _to_step(step)
        for tag, value in _flatten_metrics(metrics).items():
            self._writer.add_scalar(
                tag=tag, scalar_value=value, global_step=global_step
            )

    def close(self) -> None:
        """Flush and close the underlying event-file writer."""
        self._writer.close()


================================================
FILE: src/lczero_training/training/test_lr_schedule.py
================================================
from typing import Callable, List

import jax.numpy as jnp
import pytest

from lczero_training.training.lr_schedule import make_lr_schedule
from proto import training_config_pb2 as pb


def _sched(
    schedules: List[pb.LrSchedule],
) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """Shorthand for building the schedule under test."""
    return make_lr_schedule(schedules)


def _val(s: Callable[[jnp.ndarray], jnp.ndarray], t: int | float) -> float:
    """Evaluate a schedule at step t and return a Python float."""
    return float(s(jnp.asarray(t, dtype=jnp.float32)))


def test_rule_selection_by_starting_step() -> None:
    r0 = pb.LrSchedule(
        starting_step=0,
        duration_steps=[5],
        lr=[0.1, 0.2],
    )
    r1 = pb.LrSchedule(
        starting_step=10,
        duration_steps=[5],
        lr=[0.5, 0.5],
    )
    sched = _sched([r0, r1])
    assert _val(sched, 0) == pytest.approx(0.1)
    assert _val(sched, 9) == pytest.approx(0.2)
    assert _val(sched, 10) == pytest.approx(0.5)
    assert _val(sched, 100) == pytest.approx(0.5)


def test_default_constant_transition_and_tail() -> None:
    r = pb.LrSchedule(
        starting_step=0,
        duration_steps=[5],
        lr=[0.3, 0.8],
        # no transition specified -> CONSTANT
    )
    sched = _sched([r])
    for t in range(5):
        assert _val(sched, t) == pytest.approx(0.3)
    # Beyond period (no loop) yields last lr
    assert _val(sched, 6) == pytest.approx(0.8)


def test_linear_then_hold() -> None:
    r = pb.LrSchedule(
        starting_step=0,
        duration_steps=[3, 7],
        lr=[0.0, 0.9, 0.9],
        transition=[pb.LrSchedule.Transition.LINEAR],
    )
    sched = _sched([r])
    assert _val(sched, 0) == pytest.approx(0.0)
    assert _val(sched, 1) == pytest.approx(0.3)
    assert _val(sched, 2) == pytest.approx(0.6)
    assert _val(sched, 3) == pytest.approx(0.9)
    assert _val(sched, 8) == pytest.approx(0.9)


def test_looping_constant_segments() -> None:
    r = pb.LrSchedule(
        starting_step=0,
        duration_steps=[3, 2],
        lr=[1.0, 2.0, 3.0],
        loop=True,
    )
    sched = _sched([r])
    assert _val(sched, 0) == pytest.approx(1.0)
    assert _val(sched, 2) == pytest.approx(1.0)
    assert _val(sched, 3) == pytest.approx(2.0)
    assert _val(sched, 5) == pytest.approx(1.0)  # 5 % (3+2) == 0


def test_zero_duration_is_skipped() -> None:
    r = pb.LrSchedule(
        starting_step=0,
        duration_steps=[0, 5],
        lr=[1.0, 2.0, 3.0],
        transition=[
            pb.LrSchedule.Transition.LINEAR,
            pb.LrSchedule.Transition.LINEAR,
        ],
    )
    sched = _sched([r])
    # Interpolates over the second interval [2.0 -> 3.0]
    assert _val(sched, 0) == pytest.approx(2.0)
    assert _val(sched, 2) == pytest.approx(2.4)
    assert _val(sched, 5) == pytest.approx(3.0)


def test_chain_zero_durations_then_linear() -> None:
    r = pb.LrSchedule(
        starting_step=0,
        duration_steps=[0, 0, 0, 5],
        lr=[1.0, 2.0, 3.0, 4.0, 5.0],
        transition=[
            pb.LrSchedule.Transition.LINEAR,
            pb.LrSchedule.Transition.LINEAR,
            pb.LrSchedule.Transition.LINEAR,
            pb.LrSchedule.Transition.LINEAR,
        ],
    )
    sched = _sched([r])
    # Should use last interval [4.0 -> 5.0] linearly across 5 steps
    assert _val(sched, 0) == pytest.approx(4.0)
    assert _val(sched, 2) == pytest.approx(4.4)
    assert _val(sched, 4) == pytest.approx(4.8)


def test_cosine() -> None:
    r = pb.LrSchedule(
        starting_step=0,
        duration_steps=[4],
        lr=[0.0, 1.0],
        transition=[pb.LrSchedule.Transition.COSINE],
    )
    sched = _sched([r])
    # t=0 -> 0.0, t=0.5 -> 0.5, t=1.0 -> 1.0 approximately
    assert _val(sched, 0) == pytest.approx(0.0)
    assert _val(sched, 2) == pytest.approx(0.5, abs=1e-6)
    assert _val(sched, 4) == pytest.approx(1.0)


def test_before_first_rule_uses_earliest_first_lr() -> None:
    r = pb.LrSchedule(
        starting_step=5,
        duration_steps=[3],
        lr=[0.1, 0.2],
    )
    sched = _sched([r])
    # Before the first rule starts, schedule returns the first lr of earliest rule.
    assert _val(sched, 0) == pytest.approx(0.1)


================================================
FILE: src/lczero_training/training/training.py
================================================
import dataclasses
import logging
from datetime import datetime
from functools import partial
from typing import Any, Callable, Dict, Generator, Optional, Tuple, cast

import jax
import jax.numpy as jnp
import jax.sharding as jshard
import numpy as np
import optax
from flax import nnx
from jax import tree_util
from jax.sharding import PartitionSpec as P

from lczero_training.dataloader import DataLoader
from lczero_training.model.loss_function import LczeroLoss
from lczero_training.model.model import LczeroModel
from lczero_training.training.state import (
    JitTrainingState,
    TrainingBatch,
    TrainingSample,
)
from proto import training_config_pb2 as training_config_pb2

MetricsDict = Dict[str, Any]


@dataclasses.dataclass
class StepHookData:
    """Data passed to the step hook callback during training."""

    global_step: int
    local_step: int
    steps_per_epoch: int
    metrics: MetricsDict
    jit_state: JitTrainingState


StepHook = Callable[[StepHookData], None]

logger = logging.getLogger(__name__)


def from_dataloader(
    loader: DataLoader,
) -> Generator[tuple[np.ndarray, ...], None, None]:
    """Adapt a DataLoader into an infinite batch generator."""
    while True:
        yield loader.get_next()


class Training:
    """Owns the jitted train step and the per-epoch training loop.

    On construction the step function is jit-compiled; with more than one
    visible device it is additionally configured for data-parallel sharding
    (batch split over devices, state replicated).
    """

    optimizer_tx: optax.GradientTransformation
    train_step: Callable[
        [optax.GradientTransformation, JitTrainingState, TrainingBatch],
        Tuple[JitTrainingState, MetricsDict],
    ]
    _swa_config: Optional[training_config_pb2.SWAConfig]
    _dp_sharding: Optional[jshard.NamedSharding]

    def __init__(
        self,
        optimizer_tx: optax.GradientTransformation,
        graphdef: nnx.GraphDef,
        loss_fn: LczeroLoss,
        swa_config: Optional[training_config_pb2.SWAConfig] = None,
    ):
        self.optimizer_tx = optimizer_tx
        self._swa_config = swa_config
        self._dp_sharding = None
        jit_kwargs: Dict[str, Any] = {
            "static_argnames": ("optimizer_tx",),
            # Donate the incoming state buffers; they are replaced each step.
            "donate_argnames": ("jit_state",),
        }
        if jax.device_count() > 1:
            num_devices = jax.device_count()
            logger.info(
                f"Multi-GPU training enabled: {num_devices} devices detected"
            )
            mesh = jshard.Mesh(jax.devices(), axis_names=("batch",))
            replicated = jshard.NamedSharding(mesh, P())
            dp_sharding = jshard.NamedSharding(mesh, P("batch"))
            self._dp_sharding = dp_sharding
            # A TrainingBatch of shardings acts as the per-field spec.
            batch_sharding = TrainingBatch(
                inputs=dp_sharding,
                probabilities=dp_sharding,
                values=dp_sharding,
            )
            in_shardings = (replicated, batch_sharding)
            out_shardings = replicated
            jit_kwargs["in_shardings"] = in_shardings
            jit_kwargs["out_shardings"] = out_shardings

        @partial(jax.jit, **jit_kwargs)
        def _step(
            optimizer_tx: optax.GradientTransformation,
            jit_state: JitTrainingState,
            batch: TrainingBatch,
        ) -> Tuple[JitTrainingState, MetricsDict]:
            model = nnx.merge(graphdef, jit_state.model_state)

            def loss_for_grad(
                model_arg: LczeroModel, sample_arg: TrainingSample
            ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
                return loss_fn(model_arg, sample_arg)

            loss_vfn = jax.vmap(
                loss_for_grad,
                in_axes=(None, 0),  # (model_arg, sample_arg)
                out_axes=0,
            )

            def mean_loss_for_grad(
                model_arg: LczeroModel, batch_arg: TrainingBatch
            ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
                # vmap automatically distributes TrainingBatch over batch dimension,
                # calling loss_for_grad with TrainingSample (single samples).
                per_sample_data_loss, unweighted_losses = loss_vfn(
                    model_arg,
                    batch_arg,  # type: ignore[arg-type]
                )
                mean_loss = jnp.mean(per_sample_data_loss)
                return mean_loss, unweighted_losses

            grad_fn = nnx.value_and_grad(mean_loss_for_grad, has_aux=True)
            (mean_loss, unweighted_losses), mean_grads = grad_fn(model, batch)
            grad_norm = optax.global_norm(mean_grads)
            assert jit_state.opt_state is not None
            updates, new_opt_state = optimizer_tx.update(
                mean_grads, jit_state.opt_state, jit_state.model_state
            )
            new_model_state = optax.apply_updates(
                jit_state.model_state, updates
            )
            new_jit_state = jit_state.replace(
                step=jit_state.step + 1,
                model_state=new_model_state,
                opt_state=new_opt_state,
            )
            mean_unweighted = tree_util.tree_map(jnp.mean, unweighted_losses)
            metrics: MetricsDict = {
                "loss": mean_loss,
                "unweighted_losses": mean_unweighted,
                "grad_norm": grad_norm,
            }
            return new_jit_state, metrics

        self.train_step = cast(
            Callable[
                [optax.GradientTransformation, JitTrainingState, TrainingBatch],
                Tuple[JitTrainingState, MetricsDict],
            ],
            _step,
        )

    @staticmethod
    @jax.jit
    def _swa_tree_map(
        alpha: jax.Array,
        beta: jax.Array,
        swa_state: nnx.State,
        model_state: nnx.State,
    ) -> nnx.State:
        # Leafwise weighted average: alpha * swa + beta * model.
        return tree_util.tree_map(
            lambda a, b: alpha * a + beta * b, swa_state, model_state
        )

    def update_swa(
        self, jit_state: JitTrainingState, weight: float
    ) -> JitTrainingState:
        """Update SWA using the provided weight for the current model.

        Assumes `jit_state.swa_state` is initialized and `_swa_config` present.
        """
        logger.info(
            "Updating SWA model, weight=%f, num_averages=%f",
            weight,
            jit_state.num_averages,
        )
        assert self._swa_config is not None
        assert jit_state.swa_state is not None
        assert weight > 0.0
        max_num_averages = self._swa_config.num_averages
        # Weighted running average: existing average gets alpha, new
        # snapshot gets beta; alpha + beta == 1.
        denom = jit_state.num_averages + weight
        alpha = jit_state.num_averages / denom
        beta = weight / denom
        new_swa_state = self._swa_tree_map(
            jnp.array(alpha),
            jnp.array(beta),
            jit_state.swa_state,
            jit_state.model_state,
        )
        # Cap the effective count so old snapshots decay exponentially
        # once max_num_averages is reached.
        new_num_averages = min(
            max_num_averages, jit_state.num_averages + weight
        )
        return jit_state.replace(
            swa_state=new_swa_state, num_averages=new_num_averages
        )

    def maybe_update_swa(
        self,
        jit_state: JitTrainingState,
        steps_completed: int,
        total_steps: int,
    ) -> JitTrainingState:
        """Optionally update SWA based on configured schedule and epoch progress.

        Returns the original jit_state when no update is scheduled.
        """
        if self._swa_config is None:
            return jit_state
        period_steps = self._swa_config.period_steps
        assert period_steps > 0
        if steps_completed % period_steps == 0:
            return self.update_swa(jit_state, 1.0)
        if steps_completed == total_steps:
            # End of epoch mid-period: fold in a fractional snapshot.
            remainder = total_steps % period_steps
            return self.update_swa(jit_state, remainder / period_steps)
        return jit_state

    def _validate_and_prepare_batch(
        self, tensor_tuple: tuple[np.ndarray, ...]
    ) -> TrainingBatch:
        """Convert a raw loader tuple to a TrainingBatch, sharding it if needed.

        Raises:
            ValueError: in multi-GPU mode when the batch size is not
                divisible by the device count.
        """
        logger.info("Fetched batch from dataloader")
        # Convert tuple to TrainingBatch
        batch = TrainingBatch.from_tuple(tensor_tuple)
        # Ensure batch.inputs is jax.Array for shape access
        assert isinstance(batch.inputs, jax.Array)
        batch_size = batch.inputs.shape[0]
        if self._dp_sharding is not None:
            num_devices = jax.device_count()
            if batch_size % num_devices != 0:
                raise ValueError(
                    f"Batch size {batch_size} must be divisible by device "
                    f"count {num_devices} for multi-GPU training. "
                    f"Per-device batch size would be "
                    f"{batch_size / num_devices:.2f}"
                )
            per_device_batch_size = batch_size // num_devices
            logger.info(
                f"Multi-GPU batch: {batch_size} total "
                f"({per_device_batch_size} per device)"
            )
        if self._dp_sharding is not None:
            batch = jax.device_put(batch, self._dp_sharding)
        return batch

    def _log_step_metrics(
        self,
        step_value: int,
        local_step: int,
        num_steps: int,
        metrics: MetricsDict,
    ) -> None:
        """Log loss, per-head unweighted losses and grad norm for one step."""
        loss = float(metrics["loss"])
        unweighted_losses = {
            k: float(v) for k, v in metrics["unweighted_losses"].items()
        }
        grad_norm = float(metrics["grad_norm"])
        logger.info(
            f"Step {step_value} ({local_step}/{num_steps}), Loss: {loss}, "
            f"Unweighted losses: {unweighted_losses}, Grad norm: {grad_norm}"
        )

    def _execute_step_hook(
        self,
        step_hook: Optional[StepHook],
        step_value: int,
        local_step: int,
        num_steps: int,
        metrics: MetricsDict,
        jit_state: JitTrainingState,
    ) -> None:
        """Invoke the optional per-step callback with a StepHookData bundle."""
        if step_hook is None:
            return
        hook_data = StepHookData(
            global_step=step_value,
            local_step=local_step,
            steps_per_epoch=num_steps,
            metrics=metrics,
            jit_state=jit_state,
        )
        step_hook(hook_data)

    def run(
        self,
        jit_state: JitTrainingState,
        datagen: Generator[tuple[np.ndarray, ...], None, None],
        num_steps: int,
        step_hook: Optional[StepHook] = None,
        memory_profile_dir: Optional[str] = None,
    ) -> JitTrainingState:
        """Run num_steps training steps, returning the final state.

        Optionally dumps a device-memory profile before each step when
        memory_profile_dir is set.
        """
        assert jit_state.opt_state is not None
        if self._dp_sharding is not None:
            # Replicate the training state across the data-parallel mesh.
            replicated = jshard.NamedSharding(self._dp_sharding.mesh, P())
            jit_state = jax.device_put(jit_state, replicated)
        for local_step in range(num_steps):
            logger.info(f"Starting step {jit_state.step}")
            if memory_profile_dir is not None:
                jax.profiler.save_device_memory_profile(
                    f"{memory_profile_dir}/"
                    f"{datetime.now().strftime('%Y%m%d-%H%M%S')}"
                    f"_before_{int(jit_state.step)}.prof"
                )
            batch = self._validate_and_prepare_batch(next(datagen))
            jit_state, metrics = self.train_step(
                self.optimizer_tx, jit_state, batch
            )
            step_value = int(
                np.asarray(jax.device_get(jit_state.step)).reshape(())
            )
            jit_state = self.maybe_update_swa(
                jit_state, local_step + 1, num_steps
            )
            self._execute_step_hook(
                step_hook, step_value, local_step, num_steps, metrics, jit_state
            )
            self._log_step_metrics(step_value, local_step, num_steps, metrics)
        return jit_state


================================================
FILE: src/lczero_training/training/tune_lr.py
================================================
import csv
import logging
import sys
from contextlib import nullcontext
from functools import partial
from typing import Callable, Dict, List, Tuple, cast

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import orbax.checkpoint as ocp
from flax import nnx
from google.protobuf import text_format
from jax import tree_util

from lczero_training.dataloader import make_dataloader
from lczero_training.model.loss_function import LczeroLoss
from lczero_training.model.model import LczeroModel
from lczero_training.training.state import TrainingState
from proto.root_config_pb2 import RootConfig

from .training import Training, from_dataloader

logger = logging.getLogger(__name__)


def _prepare_batch(batch_tuple: tuple) -> Dict:
    # DataLoader now returns tuple: (inputs, probabilities, values)
    return {
        "inputs": batch_tuple[0],
        "probabilities": batch_tuple[1],
        "values": batch_tuple[2],
    }


def _make_optimizer_with_schedule(
    training_state: TrainingState,
    config: RootConfig,
    schedule: optax.Schedule,
) -> optax.GradientTransformation:
    """Build an optimizer like the training one but with a custom LR schedule.

    Only nadamw is supported here (unlike optimizer.make_gradient_transformation).

    Raises:
        ValueError: on a non-nadamw optimizer config, or when the loaded
            checkpoint has no optimizer state.
    """
    max_grad_norm = getattr(config.training, "max_grad_norm", 0.0)
    opt_config = config.training.optimizer
    if opt_config.HasField("nadamw"):
        conf = opt_config.nadamw
        tx: optax.GradientTransformation = optax.nadamw(
            schedule,
            b1=conf.beta_1,
            b2=conf.beta_2,
            eps=conf.epsilon,
            weight_decay=conf.weight_decay,
        )
    else:
        raise ValueError(
            f"Unsupported optimizer type: {opt_config.WhichOneof('optimizer_type')}"
        )
    if max_grad_norm > 0:
        tx = optax.chain(optax.clip_by_global_norm(max_grad_norm), tx)
    if training_state.jit_state.opt_state is None:
        raise ValueError("Optimizer state must be available in the checkpoint.")
    return tx


def _make_eval_step(
    graphdef: nnx.GraphDef, loss_fn: LczeroLoss
) -> Callable[[nnx.State, Dict], jax.Array]:
    """Build a jitted eval step returning the mean per-sample loss."""

    @partial(nnx.jit, static_argnames=())
    def eval_step(model_state: nnx.State, batch: Dict) -> jax.Array:
        model = nnx.merge(graphdef, model_state)

        def calculate_loss(
            model_arg: LczeroModel, batch_arg: Dict
        ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
            return loss_fn(model_arg, **batch_arg)

        loss_vfn = jax.vmap(calculate_loss, in_axes=(None, 0), out_axes=0)
        per_sample_data_loss, _ = loss_vfn(model, batch)
        return jnp.mean(per_sample_data_loss)

    return cast(Callable[[nnx.State, Dict], jax.Array], eval_step)


def _plot_results(results: List[Tuple[float, float]], plot_output: str) -> None:
    """Save a log-x plot of loss versus learning rate to plot_output."""
    logger.info("Saving plot to %s", plot_output)
    lrs, losses = zip(*results)
    plt.figure(figsize=(10, 6))
    plt.plot(lrs, losses, marker="o")
    plt.xscale("log")
    plt.xlabel("Learning Rate")
    plt.ylabel("Loss")
    plt.title("Learning Rate Finder")
    plt.grid(True, which="both", linestyle="--")
    plt.tight_layout()
    plt.savefig(plot_output, dpi=150)
    plt.close()


def tune_lr(
    *,
    config_filename: str,
    start_lr: float,
    num_steps: int,
    multiplier: float = 1.01,
    warmup_steps: int = 0,
    warmup_lr: float | None = None,
    csv_output: str | None = None,
    plot_output: str | None = None,
    num_test_batches: int = 0,
) -> None:
    # Validate numeric arguments up front; continuation is outside this view.
    if num_steps <= 0 or start_lr <= 0 or multiplier <= 0 or warmup_steps < 0:
        logger.error(
            "num_steps, start_lr, and multiplier must be positive, "
            "and warmup_steps non-negative."
def tune_lr(
    *,
    config_filename: str,
    start_lr: float,
    num_steps: int,
    multiplier: float = 1.01,
    warmup_steps: int = 0,
    warmup_lr: float | None = None,
    csv_output: str | None = None,
    plot_output: str | None = None,
    num_test_batches: int = 0,
) -> None:
    """Run an exponential learning-rate sweep from a restored checkpoint.

    Optionally warms up at a fixed LR first, then multiplies the LR by
    ``multiplier`` each step, recording train (and optionally validation)
    loss per step to CSV and/or a plot.
    """
    # --- argument validation -------------------------------------------------
    if num_steps <= 0 or start_lr <= 0 or multiplier <= 0 or warmup_steps < 0:
        logger.error(
            "num_steps, start_lr, and multiplier must be positive, "
            "and warmup_steps non-negative."
        )
        sys.exit(1)
    if warmup_steps > 0 and (warmup_lr is None or warmup_lr <= 0):
        logger.error("warmup_lr must be a positive value when warmup_steps > 0")
        sys.exit(1)

    # --- configuration and checkpoint restore --------------------------------
    config = RootConfig()
    logger.info("Reading configuration from %s", config_filename)
    with open(config_filename, "r") as cfg_file:
        text_format.Parse(cfg_file.read(), config)
    if not config.training.checkpoint.path:
        logger.error("Checkpoint path must be set in the configuration.")
        sys.exit(1)
    checkpoint_mgr = ocp.CheckpointManager(
        config.training.checkpoint.path,
        options=ocp.CheckpointManagerOptions(create=False),
    )
    logger.info("Creating state from configuration")
    empty_state = TrainingState.new_from_config(
        model_config=config.model, training_config=config.training
    )
    logger.info("Restoring checkpoint from %s", config.training.checkpoint.path)
    training_state = checkpoint_mgr.restore(
        checkpoint_mgr.latest_step(), args=ocp.args.PyTreeRestore(empty_state)
    )
    if training_state is None:
        logger.error("No checkpoint found.")
        sys.exit(1)
    logger.info("Restored checkpoint at step %d", training_state.jit_state.step)
    assert isinstance(training_state, TrainingState)

    model, _ = nnx.split(
        LczeroModel(config=config.model, rngs=nnx.Rngs(params=42))
    )
    datagen = from_dataloader(make_dataloader(config.data_loader))

    # Prepare fixed validation batches only if requested (num_test_batches > 0).
    use_validation = num_test_batches > 0
    if use_validation:
        logger.info("Fetching %d validation batches", num_test_batches)
        validation_batches = [
            tree_util.tree_map(jnp.asarray, _prepare_batch(next(datagen)))
            for _ in range(num_test_batches)
        ]
    else:
        validation_batches = []

    loss_fn = LczeroLoss(config=config.training.losses)
    eval_step = _make_eval_step(model, loss_fn)

    def avg_val_loss() -> float:
        # Mean eval loss over the fixed validation batches.
        assert use_validation
        total_loss = 0.0
        for val_batch in validation_batches:
            total_loss += float(
                eval_step(training_state.jit_state.model_state, val_batch)
            )
        return total_loss / float(num_test_batches)

    def train_one_step(
        training: Training, tx: optax.GradientTransformation
    ) -> float:
        # Advance one optimizer step; rebinds the enclosing training_state.
        nonlocal training_state
        batch = tree_util.tree_map(jnp.asarray, _prepare_batch(next(datagen)))
        new_jit_state, metrics = training.train_step(
            tx, training_state.jit_state, batch
        )
        training_state = training_state.replace(jit_state=new_jit_state)
        return float(metrics["loss"])  # training batch loss

    def run_phase(
        *,
        steps: int,
        schedule: optax.Schedule,
        lr_at: Callable[[int], float],
        label: str,
        on_result: Callable[[float, float, float | None], None],
    ) -> None:
        start_step = training_state.jit_state.step

        def offset_schedule(count: jax.Array) -> jax.Array:
            # Shift the schedule so this phase starts at its step zero.
            return schedule(count - start_step)

        tx = _make_optimizer_with_schedule(
            training_state, config, offset_schedule
        )
        training = Training(optimizer_tx=tx, graphdef=model, loss_fn=loss_fn)
        for i in range(steps):
            current_lr = lr_at(i)
            logger.info(
                "%s step %d/%d at lr %.8f", label, i + 1, steps, current_lr
            )
            train_loss = train_one_step(training, tx)
            val_loss = avg_val_loss() if use_validation else None
            on_result(current_lr, train_loss, val_loss)
            if use_validation:
                logger.info(
                    "%s at lr %.8f: train=%.6f, val=%.6f",
                    label,
                    current_lr,
                    train_loss,
                    cast(float, val_loss),
                )
            else:
                logger.info(
                    "%s train loss at lr %.8f: %.6f",
                    label,
                    current_lr,
                    train_loss,
                )

    results: List[Tuple[float, float]] = []
    with (
        open(csv_output, "w", newline="") if csv_output else nullcontext()
    ) as csv_file:
        writer = csv.writer(csv_file) if csv_file else None
        if writer:
            if use_validation:
                writer.writerow(["lr", "train_loss", "val_loss"])
            else:
                writer.writerow(["lr", "train_loss"])

        def on_result(
            lr: float, train_loss: float, val_loss: float | None
        ) -> None:
            results.append((lr, train_loss))
            if writer and csv_file:
                if use_validation:
                    writer.writerow([lr, train_loss, val_loss])
                else:
                    writer.writerow([lr, train_loss])
                csv_file.flush()

        phases = []
        if warmup_steps > 0 and warmup_lr is not None:
            phases.append(
                {
                    "label": "Warmup",
                    "steps": warmup_steps,
                    "schedule": optax.constant_schedule(warmup_lr),
                    "lr_at": lambda _: float(warmup_lr),
                }
            )
        phases.append(
            {
                "label": "Sweep",
                "steps": num_steps,
                "schedule": optax.exponential_decay(
                    start_lr, transition_steps=1, decay_rate=multiplier
                ),
                "lr_at": lambda i: start_lr * (multiplier**i),
            }
        )
        for phase in phases:
            run_phase(**phase, on_result=on_result)

    if plot_output:
        _plot_results(results, plot_output)
def make_weights_mask(
    selector: "WeightsSelector", params: "nnx.State"
) -> "nnx.State":
    """Build a boolean mask over ``params``; True means the weight is included.

    Each parameter's state path is joined into a POSIX-style path and matched
    glob-style against the selector's rules (first match wins); parameters no
    rule matches fall back to ``otherwise_include``.
    """

    def _include(path: "tuple[object, ...]", _variable: "nnx.Variable") -> bool:
        joined = PurePosixPath(*(str(part) for part in path))
        for rule in selector.rule:
            if joined.full_match(rule.match):
                return rule.include
        return selector.otherwise_include

    return nnx.map_state(_include, params)
class HeaderBar(Static):
    """Empty header bar."""

    def compose(self) -> ComposeResult:
        # Yield nothing: the unreachable bare `yield` after `return` makes
        # this method a generator, as Textual's compose contract requires.
        return
        yield


class JAXTrainingPane(Static):
    """Right pane showing JAX training status and metrics."""

    def compose(self) -> ComposeResult:
        placeholder = Static(
            "JAX Training Status\n\n"
            "Live training metrics will be displayed here when active:\n"
            "• Epoch Progress\n• Performance Metrics\n• Loss Values",
            classes="jax-training-content",
        )
        yield placeholder
This creates a full-screen interface with four main panes: - Header bar with uptime and status - Data pipeline pane (main/left) - Training status pane (right) - Log pane (bottom) """ CSS_PATH = "app.tcss" @staticmethod def add_arguments(parser: argparse.ArgumentParser) -> None: """Adds all required command-line arguments to the given parser.""" parser.add_argument( "--config", required=True, help="Path to the training configuration file", ) parser.add_argument( "--logfile", help="Path to the log file for saving TUI output", ) parser.add_argument( "--io-dump", help="Path to file for dumping raw daemon IO for debugging", ) parser.add_argument( "--daemon-flag", action="append", default=[], dest="daemon_flags", help="Extra argument to pass to the daemon (repeatable)", ) _log_stream: TextReceiveStream _daemon_process: anyio.abc.Process _communicator: AsyncCommunicator _config_file: str _logfile: Optional[str] _data_pipeline_pane: DataPipelinePane _training_schedule_widget: TrainingScheduleWidget BINDINGS = [ ("q", "quit", "Quit"), ("ctrl+c", "quit", "Quit"), ] def __init__(self, args: Optional[argparse.Namespace] = None) -> None: """ Initializes the app. If 'args' is provided, it's used directly. If 'args' is None, fallback to parsing sys.argv. """ super().__init__() if args is None: # Fallback for when run by "textual run" parser = argparse.ArgumentParser() TrainingTuiApp.add_arguments(parser) args, _ = parser.parse_known_args() # Consume configuration from the args object self._config_file: str = args.config self._logfile: Optional[str] = args.logfile self._io_dump_file: Optional[str] = args.io_dump self._daemon_flags: list[str] = args.daemon_flags async def on_load(self) -> None: """Start the daemon process and communicator when the app loads.""" # Create the daemon process via Python module execution to avoid PATH reliance. 
def compose(self) -> "ComposeResult":
    """Build the static layout: header, pipeline pane, status split, log, footer."""
    yield HeaderBar()
    self._data_pipeline_pane = DataPipelinePane()
    self._data_pipeline_pane.border_title = "Training data pipeline"
    yield self._data_pipeline_pane
    # Horizontal split below the data pipeline pane: schedule | JAX status.
    with Horizontal(id="training-status-container"):
        self._training_schedule_widget = TrainingScheduleWidget()
        self._training_schedule_widget.border_title = "Training Schedule"
        yield self._training_schedule_widget
        jax_pane = JAXTrainingPane()
        jax_pane.border_title = "JAX Training Status"
        yield jax_pane
    yield StreamingLogPane(stream=self._log_stream, logfile_path=self._logfile)
    yield Footer()
def on_mount(self) -> None:
    """Start the background workers once the widget tree is ready."""
    # The communicator owns the daemon protocol loop; keep it exclusive.
    self.run_worker(self._communicator.run(), exclusive=True)
    self.run_worker(self._send_start_training(), exclusive=False)
    self.run_worker(self._monitor_daemon_process(), exclusive=False)

async def _send_start_training(self) -> None:
    """Send StartTrainingPayload with the config file."""
    await self._communicator.send(
        StartTrainingPayload(config_filepath=self._config_file)
    )

async def _command_start_training_immediately(self) -> None:
    """Trigger immediate training without waiting for additional chunks."""
    await self._communicator.send(StartTrainingImmediatelyPayload())
    self.notify("Requested immediate training start.")

async def action_quit(self) -> None:  # type: ignore
    """Ask the daemon to shut down, escalating if SIGINT is ignored."""
    self._daemon_process.send_signal(signal.SIGINT)
    with anyio.move_on_after(20) as scope:
        await self._daemon_process.wait()
    if scope.cancelled_caught:
        # The daemon did not exit within 20 seconds: force-terminate it.
        self._daemon_process.terminate()
    self.exit()

def get_system_commands(self, screen: "Screen") -> "Iterable[SystemCommand]":
    """Add application-specific commands to the command palette."""
    yield from super().get_system_commands(screen)
    yield SystemCommand(
        "Start training immediately",
        "Begin the next training cycle without waiting for chunks.",
        self._command_start_training_immediately,
    )

async def on_training_status(self, payload: "TrainingStatusPayload") -> None:
    """Propagate a status update into the pipeline and schedule widgets."""
    self._data_pipeline_pane.update_metrics(
        payload.dataloader_1_second, payload.dataloader_total
    )
    # Update training schedule widget
    self._training_schedule_widget.update_training_schedule(
        payload.training_schedule
    )
{ dock: top; height: 1; background: $primary-darken-1; color: $text; } DataPipelinePane { background: $surface; color: $text; border: none; margin: 1; padding: 0 1; layout: vertical; overflow-y: auto; height: 27; } #training-status-container { height: 12; layout: horizontal; } TrainingScheduleWidget { background: $surface; color: $text; border: solid $primary; margin: 1; width: 1fr; } JAXTrainingPane { background: $surface; color: $text; border: solid $primary; margin: 1; width: 1fr; } RichLog { height: 1fr; background: $panel; color: $text; margin: 0; width: 100%; overflow-y: scroll; } .stage-content, .queue-content { text-align: left; } .dataloader-row { layout: grid; grid-size: 6 8; grid-columns: 1fr 1fr 1fr 1fr 1fr 1fr; grid-rows: 1; grid-gutter: 0; border: none; padding: 0; margin: 0; height: auto; min-height: 1; width: 100%; } .stage-row { background: $surface; color: $text; grid-size: 5 8; grid-columns: 1fr 1fr 1fr 1fr 1fr; } .queue-row { background: $surface-darken-2; color: $text; grid-size: 6 1; } .statistics-row { background: $surface-darken-1; color: $text; grid-size: 1 1; } .statistics-full { background: $surface-darken-1; color: $text; padding: 0 1; margin: 0; height: auto; min-height: 1; width: 100%; } .row-label { padding: 0 1 0 0; margin: 0; height: 1; max-height: 1; text-style: bold; color: $text; content-align: left middle; text-wrap: nowrap; overflow: hidden; text-overflow: ellipsis; } .grid-spacer { width: 1; height: 1; max-height: 1; background: transparent; } .metric-chip { background: $surface-darken-1; color: $text; padding: 0 1; margin: 0 1 0 0; height: 1; max-height: 1; border: none; content-align: center middle; width: auto; min-width: 10; text-wrap: nowrap; overflow: hidden; } .load-chip { background: $primary-darken-2; color: $text; } .info-chip { background: $surface-darken-1; color: $text; } .warning-chip { background: $warning; color: $text; } .queue-name-chip { background: transparent; color: $text-muted; padding: 0; margin-right: 
1; height: 1; max-height: 1; text-wrap: nowrap; overflow: hidden; } .queue-fill { height: 1; max-height: 1; width: 1fr; margin-right: 1; } .queue-fill-text { background: transparent; color: $text; margin: 0; padding: 0; height: 1; max-height: 1; text-wrap: nowrap; overflow: hidden; width: auto; } .queue-rate--zero { color: $error; } .queue-rate, .queue-total { text-wrap: nowrap; } .jax-training-content { padding: 1; } /* Training widgets styles */ .time-label { width: auto; padding: 0; margin-right: 1; text-align: left; } .time-progress-bar { width: 1fr; height: 1; } .time-progress-bar .bar--indeterminate { background: $panel; } .time-progress-bar .bar--bar { background: $panel; } .time-ratio { width: auto; padding: 0; text-align: right; } TimeProgressWidget { height: 1; layout: horizontal; } ChunksProgressWidget { height: 1; layout: horizontal; } #uptime-stage-display, #epochs-display { height: 1; } .log-content { padding: 0; } ================================================ FILE: src/lczero_training/tui/data_pipeline_pane.py ================================================ # ABOUTME: Data pipeline pane widget for displaying DataLoader metrics. # ABOUTME: Shows a grid of pipeline stages and queues with their metrics. 
def __init__(self, **kwargs: "Any") -> None:
    """Initialise empty widget registries; rows appear as metrics arrive."""
    super().__init__(**kwargs)
    self._stage_widgets: "dict[str, StageWidget]" = {}
    self._queue_widgets: "dict[str, dict[str, QueueWidget]]" = {}
    self._queue_order: "dict[str, list[str]]" = {}
    self._statistics_widgets: "dict[str, dict[str, StatisticsRowWidget]]" = {}
    self._statistics_order: "dict[str, list[str]]" = {}
    self._stage_order: "list[str]" = []

def compose(self) -> "ComposeResult":
    """The pane starts empty and rows are added when metrics arrive."""
    yield from ()

def _friendly_title(self, stage_key: str) -> str:
    # Fall back to a title-cased version of the key when no mapping exists.
    return FRIENDLY_STAGE_NAMES.get(
        stage_key, stage_key.replace("_", " ").title()
    )

def _ensure_stage_widget(self, stage_key: str) -> "tuple[StageWidget, bool]":
    """Return the row widget for a stage, creating it on first sight.

    The bool is True when the widget was newly created (caller mounts it).
    """
    existing = self._stage_widgets.get(stage_key)
    if existing is not None:
        return existing, False
    widget = StageWidget(
        stage_key=stage_key,
        fallback_name=self._friendly_title(stage_key),
    )
    self._stage_widgets[stage_key] = widget
    self._queue_widgets[stage_key] = {}
    self._queue_order[stage_key] = []
    self._stage_order.append(stage_key)
    return widget, True
stage_queue_widgets = self._queue_widgets.setdefault(stage_key, {}) queue_order = self._queue_order.setdefault(stage_key, []) new_widgets: list[QueueWidget] = [] friendly = self._friendly_title(stage_key) for index, queue_metric in enumerate(stage_metric.queue_metrics): queue_identifier = queue_metric.name or f"__index__{index}" if queue_identifier in stage_queue_widgets: continue label_suffix = ( f" {queue_metric.name}" if queue_metric.name else f" #{index + 1}" ) queue_widget = QueueWidget( stage_key=stage_key, stage_name=f"{friendly} queue{label_suffix}", queue_name=queue_identifier, ) stage_queue_widgets[queue_identifier] = queue_widget queue_order.append(queue_identifier) new_widgets.append(queue_widget) return new_widgets def _ensure_statistics_widgets( self, stage_key: str, stage_metric: training_metrics_pb2.StageMetricProto, ) -> list[StatisticsRowWidget]: stage_statistics_widgets = self._statistics_widgets.setdefault( stage_key, {} ) statistics_order = self._statistics_order.setdefault(stage_key, []) new_widgets: list[StatisticsRowWidget] = [] for statistics_metric in stage_metric.statistics_metrics: metric_name = statistics_metric.name or "" if metric_name in stage_statistics_widgets: continue label = metric_name.replace("_", " ").title() statistics_widget = StatisticsRowWidget( stage_key=stage_key, metric_name=metric_name, label=label, ) stage_statistics_widgets[metric_name] = statistics_widget statistics_order.append(metric_name) new_widgets.append(statistics_widget) return new_widgets def _mount_widgets( self, widgets: Iterable[StageWidget | QueueWidget | StatisticsRowWidget] ) -> None: async def _do_mount() -> None: await self.mount(*widgets) self.call_later(_do_mount) def _ensure_rows( self, metrics: training_metrics_pb2.DataLoaderMetricsProto, ) -> None: new_widgets: list[StageWidget | QueueWidget | StatisticsRowWidget] = [] for stage_metric in metrics.stage_metrics: stage_key = stage_metric.name if not stage_key: continue stage_widget, created = 
def update_metrics(
    self,
    dataloader_1_second: "training_metrics_pb2.DataLoaderMetricsProto | None",
    dataloader_total: "training_metrics_pb2.DataLoaderMetricsProto | None",
) -> None:
    """Update all pipeline stages and queues with new metrics."""
    # Prefer the cumulative snapshot for discovering new rows.
    layout_source = dataloader_total or dataloader_1_second
    if layout_source:
        self._ensure_rows(layout_source)
    for stage_key in self._stage_order:
        stage_widget = self._stage_widgets.get(stage_key)
        if stage_widget:
            stage_widget.update_metrics(dataloader_1_second, dataloader_total)
        for stats_key in self._statistics_order.get(stage_key, []):
            stats_widget = self._statistics_widgets[stage_key].get(stats_key)
            if stats_widget:
                stats_widget.update_metrics(
                    dataloader_1_second, dataloader_total
                )
        for queue_key in self._queue_order.get(stage_key, []):
            queue_widget = self._queue_widgets[stage_key].get(queue_key)
            if queue_widget:
                queue_widget.update_metrics(
                    dataloader_1_second, dataloader_total
                )
metrics.stage_metrics: if stage_metric.name == stage_key: return stage_metric return None def _collect_metric_names( stage_1s: training_metrics_pb2.StageMetricProto | None, stage_total: training_metrics_pb2.StageMetricProto | None, attribute: str, ) -> list[str]: names: list[str] = [] def _add_from( stage_metric: training_metrics_pb2.StageMetricProto | None, ) -> None: if not stage_metric: return for metric in getattr(stage_metric, attribute): name = metric.name if metric.name else "" if name not in names: names.append(name) _add_from(stage_total) _add_from(stage_1s) return names def _find_load_metric( stage_metric: training_metrics_pb2.StageMetricProto | None, metric_name: str, ) -> training_metrics_pb2.LoadMetricProto | None: if not stage_metric: return None for load_metric in stage_metric.load_metrics: if (load_metric.name or "") == metric_name: return load_metric return None def _find_count_metric( stage_metric: training_metrics_pb2.StageMetricProto | None, metric_name: str, ) -> training_metrics_pb2.CountMetricProto | None: if not stage_metric: return None for count_metric in stage_metric.count_metrics: if (count_metric.name or "") == metric_name: return count_metric return None def _find_gauge_metric( stage_metric: training_metrics_pb2.StageMetricProto | None, metric_name: str, ) -> training_metrics_pb2.GaugeMetricProto | None: if not stage_metric: return None for gauge_metric in stage_metric.gauge_metrics: if (gauge_metric.name or "") == metric_name: return gauge_metric return None def _find_statistics_metric( stage_metric: training_metrics_pb2.StageMetricProto | None, metric_name: str, ) -> training_metrics_pb2.StatisticsProtoDouble | None: if not stage_metric: return None for stats_metric in stage_metric.statistics_metrics: if (stats_metric.name or "") == metric_name: return stats_metric return None def _get_queue_metric( stage_metric: training_metrics_pb2.StageMetricProto | None, queue_name: str | None, ) -> training_metrics_pb2.QueueMetricProto | None: 
if not stage_metric or not stage_metric.queue_metrics: return None if queue_name is None: return stage_metric.queue_metrics[0] if queue_name.startswith("__index__"): try: index = int(queue_name.removeprefix("__index__")) except ValueError: index = -1 if 0 <= index < len(stage_metric.queue_metrics): return stage_metric.queue_metrics[index] for queue_metric in stage_metric.queue_metrics: if (queue_metric.name or "") == queue_name: return queue_metric return None def format_si(value: int, precision: int = 1) -> str: if value == 0: return "0" units = [ (1_000_000_000_000, "T"), (1_000_000_000, "G"), (1_000_000, "M"), (1_000, "k"), ] for threshold, unit in units: if value >= threshold: result = value / threshold if precision == 0: return f"{int(result)}{unit}" return f"{result:.{precision}f}{unit}".rstrip("0").rstrip(".") return str(value) def format_full_number(value: int) -> str: if value < 10_000: return str(value) return f"{value:_}".replace("_", "'") def _format_load( load_metric: training_metrics_pb2.LoadMetricProto | None, label: str, ) -> str: if not load_metric: return f"{label} --" total_part = ( f"{load_metric.total_seconds:.0f}" if load_metric.total_seconds > 0 else "--" ) return f"{label} {load_metric.load_seconds:.1f}/{total_part}s" def _format_count( count_metric_1s: training_metrics_pb2.CountMetricProto | None, count_metric_total: training_metrics_pb2.CountMetricProto | None, label: str, ) -> str: if count_metric_1s and count_metric_total: rate = format_si(count_metric_1s.count) total = format_full_number(count_metric_total.count) return f"{label} {rate}/s ({total} total)" if count_metric_total: return f"{label} {format_full_number(count_metric_total.count)}" if count_metric_1s: return f"{label} {format_si(count_metric_1s.count)}/s" return f"{label} --" def _format_gauge( gauge_metric: training_metrics_pb2.GaugeMetricProto | None, label: str, ) -> str: if not gauge_metric: return f"{label} --" if gauge_metric.HasField("capacity"): value_text = 
format_full_number(gauge_metric.value) capacity_text = format_full_number(gauge_metric.capacity) return f"{label} {value_text}/{capacity_text}" return f"{label} {format_full_number(gauge_metric.value)}" def _format_statistics( stats_1s: training_metrics_pb2.StatisticsProtoDouble | None, stats_total: training_metrics_pb2.StatisticsProtoDouble | None, label: str, ) -> str: parts = [label] if stats_1s and stats_1s.count > 0: avg_1s = stats_1s.sum / stats_1s.count parts.append( f"per 1s: avg {avg_1s:.1f} (min {stats_1s.min:.1f}, max {stats_1s.max:.1f}, count {format_si(stats_1s.count)})" ) if stats_total and stats_total.count > 0: avg_total = stats_total.sum / stats_total.count parts.append( f"total: avg {avg_total:.1f} (min {stats_total.min:.1f}, max {stats_total.max:.1f}, count {format_full_number(stats_total.count)})" ) if len(parts) == 1: return f"{label} --" return "; ".join(parts) def _average_queue_fullness( queue_metric: training_metrics_pb2.QueueMetricProto | None, ) -> int | None: if not queue_metric: return None if ( queue_metric.HasField("queue_fullness") and queue_metric.queue_fullness.count > 0 ): return int( queue_metric.queue_fullness.sum / queue_metric.queue_fullness.count ) return None def _canonical_stage_name( stage_metric: training_metrics_pb2.StageMetricProto | None, fallback: str | None, stage_key: str | None, ) -> str: if stage_metric and stage_metric.name: return stage_metric.name if fallback: return fallback if stage_key: return stage_key return "--" class BaseRowWidget(Widget): """Base class for pipeline rows that renders a name and content widgets.""" MAX_GRID_ROWS = 8 # Supports up to 28 chips (4 per row, 7 content rows) def __init__( self, stage_key: str | None = None, fallback_name: str | None = None, row_type: str = "stage-row", **kwargs: Any, ) -> None: classes = f"dataloader-row {row_type}" super().__init__(classes=classes, **kwargs) self.stage_key = stage_key self._fallback_name = fallback_name self._name_label = Static( 
            _canonical_stage_name(None, fallback_name, stage_key),
            classes="row-label",
        )
        # Widgets queued for mounting (mounted lazily in on_mount).
        self._content_widgets: list[Widget] = []
        # Spacers inserted to keep grid rows aligned with column 0.
        self._spacers: list[Static] = []

    def compose(self) -> ComposeResult:
        # Name label goes in first cell (row 0, col 0)
        yield self._name_label

    def on_mount(self) -> None:
        """Mount queued content widgets once the row is attached to the DOM."""
        if self._content_widgets:

            async def _mount_initial() -> None:
                # Mount chips with spacers interleaved at the start of each row
                for i, widget in enumerate(self._content_widgets):
                    # After every 4 chips, add a spacer for the next row's column 0
                    if i > 0 and i % 4 == 0:
                        spacer = Static("", classes="grid-spacer")
                        self._spacers.append(spacer)
                        await self.mount(spacer)
                    await self.mount(widget)

            self.call_later(_mount_initial)

    def add_content_widget(self, widget: Widget) -> None:
        """Queue a widget for display in this row; duplicates are ignored."""
        if widget in self._content_widgets:
            return
        self._content_widgets.append(widget)

    def _update_name(
        self,
        stage_metric: training_metrics_pb2.StageMetricProto | None,
    ) -> None:
        """Refresh the row label from the given stage metric (or fallbacks)."""
        self._name_label.update(
            _canonical_stage_name(
                stage_metric, self._fallback_name, self.stage_key
            )
        )


class StageWidget(BaseRowWidget):
    """Row widget that renders all metrics exposed by a stage."""

    def __init__(
        self,
        stage_key: str,
        fallback_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            stage_key=stage_key,
            fallback_name=fallback_name,
            row_type="stage-row",
            **kwargs,
        )
        # One chip per metric, keyed "<kind>:<metric name>".
        self._chips: Dict[str, Static] = {}

    def _ensure_chip(self, key: str, default_text: str, classes: str) -> Static:
        """Return the chip for `key`, creating and registering it on first use."""
        chip = self._chips.get(key)
        if chip is None:
            chip = Static(default_text, classes=f"metric-chip {classes}")
            self._chips[key] = chip
            self.add_content_widget(chip)
        return chip

    def _update_last_chunk_chip(
        self,
        stage_1s: training_metrics_pb2.StageMetricProto | None,
        stage_total: training_metrics_pb2.StageMetricProto | None,
    ) -> None:
        """Show the most recent chunk key, preferring the 1-second window."""
        stage = None
        if stage_1s and stage_1s.HasField("last_chunk_key"):
            stage = stage_1s
        elif stage_total and stage_total.HasField("last_chunk_key"):
            stage = stage_total
        if not stage:
            return
        last_value = stage.last_chunk_key or "--"
        chip = self._ensure_chip("info:last", "last --", "info-chip")
        chip.update(f"last {last_value}")

    def _update_anchor_chip(
        self,
        stage_total: training_metrics_pb2.StageMetricProto | None,
    ) -> None:
        """Show the anchor value; it is only read from the running total."""
        if not stage_total or not stage_total.HasField("anchor"):
            return
        chip = self._ensure_chip("info:anchor", "anchor --", "info-chip")
        chip.update(f"anchor {stage_total.anchor}")

    def update_metrics(
        self,
        dataloader_1_second: training_metrics_pb2.DataLoaderMetricsProto | None,
        dataloader_total: training_metrics_pb2.DataLoaderMetricsProto | None,
    ) -> None:
        """Refresh every metric chip from the 1-second and total snapshots."""
        if self.stage_key is None:
            return
        stage_metric_1s = _find_stage_metric(
            dataloader_1_second, self.stage_key
        )
        stage_metric_total = _find_stage_metric(
            dataloader_total, self.stage_key
        )
        self._update_name(stage_metric_1s or stage_metric_total)
        load_names = _collect_metric_names(
            stage_metric_1s, stage_metric_total, "load_metrics"
        )
        for load_name in load_names:
            label = load_name or "load"
            # Prefer the 1-second value; fall back to the running total.
            load_metric = _find_load_metric(stage_metric_1s, load_name)
            if load_metric is None:
                load_metric = _find_load_metric(stage_metric_total, load_name)
            chip = self._ensure_chip(
                f"load:{load_name}", f"{label} --", "load-chip"
            )
            chip.update(_format_load(load_metric, label=label))
        count_names = _collect_metric_names(
            stage_metric_1s, stage_metric_total, "count_metrics"
        )
        for count_name in count_names:
            label = count_name or "count"
            count_metric_1s = _find_count_metric(stage_metric_1s, count_name)
            count_metric_total = _find_count_metric(
                stage_metric_total, count_name
            )
            chip = self._ensure_chip(
                f"count:{count_name}", f"{label} --", "info-chip"
            )
            chip.update(
                _format_count(count_metric_1s, count_metric_total, label=label)
            )
        gauge_names = _collect_metric_names(
            stage_metric_1s, stage_metric_total, "gauge_metrics"
        )
        for gauge_name in gauge_names:
            label = gauge_name or "gauge"
            # Gauges prefer the total snapshot (unlike loads above).
            gauge_metric = _find_gauge_metric(stage_metric_total, gauge_name)
            if gauge_metric is None:
                gauge_metric = _find_gauge_metric(stage_metric_1s, gauge_name)
            chip = self._ensure_chip(
                f"gauge:{gauge_name}", f"{label} --", "info-chip"
            )
            chip.update(_format_gauge(gauge_metric, label=label))
        self._update_last_chunk_chip(stage_metric_1s, stage_metric_total)
        self._update_anchor_chip(stage_metric_total)


class StatisticsRowWidget(BaseRowWidget):
    """Full-width row for statistics metrics."""

    def __init__(
        self,
        stage_key: str,
        metric_name: str,
        label: str,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            stage_key=stage_key,
            fallback_name=None,
            row_type="statistics-row",
            **kwargs,
        )
        self._metric_name = metric_name
        self._label = label
        self._stats_label = Static(f"{label} --", classes="statistics-full")
        self.add_content_widget(self._stats_label)

    def compose(self) -> ComposeResult:
        # No name label for statistics rows; `yield` after `return` keeps
        # this an (empty) generator as compose() requires.
        return
        yield

    def on_mount(self) -> None:
        if self._content_widgets:

            async def _mount_initial() -> None:
                await self.mount(*self._content_widgets)

            self.call_later(_mount_initial)

    def update_metrics(
        self,
        dataloader_1_second: training_metrics_pb2.DataLoaderMetricsProto | None,
        dataloader_total: training_metrics_pb2.DataLoaderMetricsProto | None,
    ) -> None:
        """Render the statistics metric, or a placeholder when unavailable."""
        if not self.stage_key:
            self._stats_label.update(f"{self._label} --")
            return
        stage_1s = _find_stage_metric(dataloader_1_second, self.stage_key)
        stage_total = _find_stage_metric(dataloader_total, self.stage_key)
        stats_1s = _find_statistics_metric(stage_1s, self._metric_name)
        stats_total = _find_statistics_metric(stage_total, self._metric_name)
        self._stats_label.update(
            _format_statistics(stats_1s, stats_total, self._label)
        )


class QueueWidget(BaseRowWidget):
    """Row widget for queue metrics between stages."""

    def __init__(
        self,
        stage_key: str | None = None,
        stage_name: str | None = None,
        queue_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            stage_key=stage_key,
            fallback_name=stage_name,
            row_type="queue-row",
            **kwargs,
        )
        self._queue_name = queue_name
        self._queue_name_chip = Static(
            "queue --", classes="metric-chip queue-name-chip"
        )
        self._rate_chip = Static("rate --/s", classes="metric-chip queue-rate")
        self._total_chip = Static("total --", classes="metric-chip queue-total")
        self._drop_chip = Static("dropped --", classes="metric-chip")
        self._fill_bar = ProgressBar(
            classes="queue-fill",
            show_percentage=False,
            show_eta=False,
        )
        self._fill_text = Static("--/--", classes="metric-chip queue-fill-text")
        self.add_content_widget(self._queue_name_chip)
        self.add_content_widget(self._rate_chip)
        self.add_content_widget(self._total_chip)
        self.add_content_widget(self._drop_chip)
        self.add_content_widget(self._fill_bar)
        self.add_content_widget(self._fill_text)

    def compose(self) -> ComposeResult:
        # Queue widgets don't show the name label - use all 6 columns
        return
        yield  # Make this a generator

    def on_mount(self) -> None:
        # Mount all content widgets without spacers (use all 6 columns per row)
        if self._content_widgets:

            async def _mount_initial() -> None:
                await self.mount(*self._content_widgets)

            self.call_later(_mount_initial)

    def update_metrics(
        self,
        dataloader_1_second: training_metrics_pb2.DataLoaderMetricsProto | None,
        dataloader_total: training_metrics_pb2.DataLoaderMetricsProto | None,
    ) -> None:
        """Refresh name, rate, total, drops and fill gauge for this queue."""
        if not self.stage_key:
            # No stage to read from: reset every chip to its placeholder.
            self._queue_name_chip.update("queue --")
            self._rate_chip.update("rate --/s")
            self._total_chip.update("total --")
            self._drop_chip.update("dropped --")
            self._drop_chip.remove_class("warning-chip")
            self._fill_bar.total = 1
            self._fill_bar.progress = 0
            self._fill_text.update("--/--")
            return
        stage_1sec = _find_stage_metric(dataloader_1_second, self.stage_key)
        stage_total = _find_stage_metric(dataloader_total, self.stage_key)
        self._update_name(stage_1sec or stage_total)
        queue_1sec = _get_queue_metric(stage_1sec, self._queue_name)
        queue_total = _get_queue_metric(stage_total, self._queue_name)
        # Resolve the displayed queue name: metric snapshots win over the
        # name this widget was configured with.
        queue_name = None
        if queue_1sec and queue_1sec.name:
            queue_name = queue_1sec.name
        elif queue_total and queue_total.name:
            queue_name = queue_total.name
        elif self._queue_name:
            queue_name = self._queue_name
        self._queue_name_chip.update(
            f"queue {queue_name}" if queue_name else "queue --"
        )
        rate = queue_1sec.get_count if queue_1sec else 0
        self._rate_chip.update(f"rate {format_si(rate)}/s")
        if rate == 0:
            self._rate_chip.add_class("queue-rate--zero")
        else:
            self._rate_chip.remove_class("queue-rate--zero")
        if queue_total:
            self._total_chip.update(
                f"total {format_full_number(queue_total.get_count)}"
            )
        else:
            self._total_chip.update("total --")
        has_drops = False
        if queue_1sec and queue_1sec.drop_count:
            self._drop_chip.update(
                f"dropped {format_si(queue_1sec.drop_count)}/s"
            )
            has_drops = True
        elif queue_total and queue_total.drop_count:
            self._drop_chip.update(
                f"dropped {format_full_number(queue_total.drop_count)}"
            )
            has_drops = True
        else:
            self._drop_chip.update("dropped --")
            has_drops = False
        if has_drops:
            self._drop_chip.add_class("warning-chip")
        else:
            self._drop_chip.remove_class("warning-chip")
        size = _average_queue_fullness(queue_1sec)
        capacity: int | None = None
        if queue_1sec and queue_1sec.queue_capacity > 0:
            capacity = queue_1sec.queue_capacity
        elif queue_total and queue_total.queue_capacity > 0:
            capacity = queue_total.queue_capacity
        if capacity and capacity > 0:
            self._fill_bar.total = capacity
            if size is not None:
                # Clamp: averaged fullness may momentarily exceed capacity.
                self._fill_bar.progress = min(size, capacity)
                fill_text = (
                    f"{format_full_number(size)}/{format_full_number(capacity)}"
                )
            else:
                self._fill_bar.progress = 0
                fill_text = f"--/{format_full_number(capacity)}"
        else:
            self._fill_bar.total = 1
            self._fill_bar.progress = 0
            fill_text = "--/--"
        self._fill_text.update(fill_text)
class StreamingLogPane(RichLog):
    """Log pane that streams output from an async text stream.

    Optionally tees every received line to a logfile on disk (append mode).
    """

    def __init__(
        self,
        stream: TextReceiveStream,
        logfile_path: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            highlight=True, markup=True, max_lines=1000, wrap=True, **kwargs
        )
        self._stream = stream
        self._logfile_path = logfile_path
        self._logfile: Optional[Path] = None
        self._logfile_handle: Optional[TextIO] = None
        if logfile_path:
            self._logfile = Path(logfile_path)

    def on_mount(self) -> None:
        """Start the async reading task when the widget is mounted."""
        if self._logfile:
            self._write_banner()
        self.run_worker(self._read_stream())

    def _write_banner(self) -> None:
        """Write a session banner to the logfile."""
        if not self._logfile:
            return
        self._logfile.parent.mkdir(parents=True, exist_ok=True)
        # Handle stays open for the pane's lifetime; writes are flushed
        # eagerly so the file is useful even if the TUI is killed.
        self._logfile_handle = self._logfile.open("a", encoding="utf-8")
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        banner = (
            f"\n{'=' * 80}\n"
            f"LCZero Training TUI Session Started: {timestamp}\n"
            f"{'=' * 80}\n"
        )
        self._logfile_handle.write(banner)
        self._logfile_handle.flush()

    def _write_to_file(self, line: str) -> None:
        """Write a line to the logfile."""
        if not self._logfile_handle:
            return
        self._logfile_handle.write(f"{line}\n")
        self._logfile_handle.flush()

    async def _read_stream(self) -> None:
        """Async function that reads lines from the text stream."""
        try:
            async for line in self._stream:
                line = line.strip()
                if line:
                    self.write(line)
                    self._write_to_file(line)
        except Exception:
            # NOTE(review): best-effort by design — any stream error simply
            # ends the pane's feed silently.
            pass
class TimeProgressWidget(Static):
    """A widget to display a label, progress bar, and time ratio."""

    def __init__(self, label: str, *, id: str | None = None) -> None:
        super().__init__(id=id)
        self._label = label

    def compose(self) -> ComposeResult:
        yield Static(self._label, classes="time-label")
        yield ProgressBar(show_eta=False, classes="time-progress-bar")
        yield Static("", classes="time-ratio")

    def update_progress(
        self,
        current: float,
        total: float,
        current_formatted: str | None = None,
        total_formatted: str | None = None,
    ) -> None:
        """Update the progress bar and ratio text.

        When `*_formatted` strings are given they are displayed verbatim;
        otherwise the raw values are shown truncated to integers.
        """
        progress_bar = self.query_one(ProgressBar)
        # total == 0 maps to None, which puts the bar in indeterminate mode.
        progress_bar.total = total or None
        progress_bar.progress = current
        current_str = (
            current_formatted
            if current_formatted is not None
            else str(int(current))
        )
        total_str = (
            total_formatted if total_formatted is not None else str(int(total))
        )
        self.query_one(".time-ratio", Static).update(
            f"{current_str}/{total_str}"
        )


def format_time_duration(seconds: float) -> str:
    """Format time duration in seconds to human readable format with days support.

    Returns "--" for zero or negative input.
    """
    if seconds <= 0:
        return "--"
    total_seconds = int(seconds)
    days, remainder = divmod(total_seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, secs = divmod(remainder, 60)
    if days > 0:
        return f"{days}d {hours:02d}:{minutes:02d}:{secs:02d}"
    if hours > 0:
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"
    return f"{minutes:02d}:{secs:02d}"


class TrainingScheduleWidget(Static):
    """A widget to display training schedule information."""

    def compose(self) -> ComposeResult:
        yield Static("Uptime: -- Stage: --", id="uptime-stage-display")
        yield Static("Completed epochs: 0", id="epochs-display")
        yield TimeProgressWidget("New Chunks:", id="chunks-progress")
        yield TimeProgressWidget("Training time", id="training-time-progress")
        yield TimeProgressWidget("Cycle time", id="cycle-time-progress")

    def update_training_schedule(
        self, data: TrainingScheduleData | None
    ) -> None:
        """Update the widget with new training schedule data."""
        if not data:
            return
        uptime_str = format_time_duration(data.total_uptime_seconds)
        self.query_one("#uptime-stage-display", Static).update(
            f"Uptime: {uptime_str} Stage: {data.current_stage.value}"
        )
        self.query_one("#epochs-display", Static).update(
            f"Completed epochs: {data.completed_epochs_since_start}"
        )
        self.query_one("#chunks-progress", TimeProgressWidget).update_progress(
            current=data.new_chunks_since_training_start,
            total=data.chunks_to_wait,
        )
        self.query_one(
            "#training-time-progress", TimeProgressWidget
        ).update_progress(
            current=data.current_training_time_seconds,
            total=data.previous_training_time_seconds,
            current_formatted=format_time_duration(
                data.current_training_time_seconds
            ),
            total_formatted=format_time_duration(
                data.previous_training_time_seconds
            ),
        )
        self.query_one(
            "#cycle-time-progress", TimeProgressWidget
        ).update_progress(
            current=data.current_cycle_time_seconds,
            total=data.previous_cycle_time_seconds,
            current_formatted=format_time_duration(
                data.current_cycle_time_seconds
            ),
            total_formatted=format_time_duration(
                data.previous_cycle_time_seconds
            ),
        )
# Per-direction ray/jump offset tables (board index deltas), shared by
# make_map() and make_pos_enc().
move = np.arange(1, 8)
diag = np.array([
    move + move * 8,
    move - move * 8,
    move * -1 - move * 8,
    move * -1 + move * 8,
])
orthog = np.array([
    move,
    move * -8,
    move * -1,
    move * 8,
])
knight = np.array([
    [2 + 1 * 8],
    [2 - 1 * 8],
    [1 - 2 * 8],
    [-1 - 2 * 8],
    [-2 - 1 * 8],
    [-2 + 1 * 8],
    [-1 + 2 * 8],
    [1 + 2 * 8],
])
promos = np.array([2 * 8, 3 * 8, 4 * 8])
pawn_promotion = np.array([
    -1 + promos,
    0 + promos,
    1 + promos,
])


def _reachable(i, j, with_promotions):
    """Sorted put-down indices reachable from square 8*i + j.

    i is the rank index (0 = rank 1) and j the file index (0 = file a).
    Board targets are 0..63; when with_promotions is true, promotion
    targets >= 64 are added for squares with i == 6.
    """
    parts = [
        orthog[0][:7 - j],                 # towards higher files
        orthog[2][:j],                     # towards lower files
        orthog[1][:i],                     # towards lower ranks
        orthog[3][:7 - i],                 # towards higher ranks
        diag[0][:min(7 - i, 7 - j)],
        diag[3][:min(7 - i, j)],
        diag[1][:min(i, 7 - j)],
        diag[2][:min(i, j)],
    ]
    if i < 7 and j < 6:
        parts.append(knight[0])
    if i > 0 and j < 6:
        parts.append(knight[1])
    if i > 1 and j < 7:
        parts.append(knight[2])
    if i > 1 and j > 0:
        parts.append(knight[3])
    if i > 0 and j > 1:
        parts.append(knight[4])
    if i < 7 and j > 1:
        parts.append(knight[5])
    if i < 6 and j > 0:
        parts.append(knight[6])
    if i < 6 and j < 7:
        parts.append(knight[7])
    if with_promotions and i == 6:
        if j > 0:
            parts.append(pawn_promotion[0])
        parts.append(pawn_promotion[1])
        if j < 7:
            parts.append(pawn_promotion[2])
    return 8 * i + j + np.sort(np.int32(np.concatenate(parts)))


def make_map():
    """theoretically possible put-down squares (numpy array) for each
    pick-up square (list element).  squares are [0, 1, ..., 63] for
    [a1, b1, ..., h8].  squares after 63 are promotion squares.  each
    successive "row" beyond 63 (ie. 64:72, 72:80, 80:88) are for
    over-promotions to queen, rook, and bishop; respectively.  a pawn
    traverse to row 56:64 signifies a "default" promotion to a knight.
    """
    traversable = [
        _reachable(rank, file, True)
        for rank in range(8)
        for file in range(8)
    ]

    z = np.zeros((64 * 64 + 8 * 24, 1858), dtype=np.int32)

    # First pass: standard (board-to-board) moves; the policy index
    # advances by one per move.
    idx = 0
    for pickup, putdowns in enumerate(traversable):
        for putdown in putdowns:
            if putdown < 64:
                z[putdown + 64 * pickup, idx] = 1
                idx += 1

    # Second pass: promotions; the policy index advances by the irregular
    # strides assembled below.
    edge = np.array([3, -2, 3, -2, 3])
    middle = np.array([3, 3, -5, 3, 3, -5, 3, 3, 1])
    strides = np.append(edge, 1)
    for _ in range(6):
        strides = np.append(strides, middle)
    strides = np.append(strides, edge)
    strides = np.append(strides, 0)

    step = 0
    for pickup, putdowns in enumerate(traversable):
        for putdown in putdowns:
            if putdown >= 64:
                pickup_file = pickup % 8
                promotion_file = putdown % 8
                promotion_rank = putdown // 8 - 8
                z[4096 + pickup_file * 24
                  + promotion_file * 3 + promotion_rank, idx] = 1
                idx += strides[step]
                step += 1
    return z


def make_pos_enc():
    """Positional encoding of shape (1, 64, 64): each square's row holds
    +1 on every board square reachable from it (promotions excluded) and
    -1 on the square itself."""
    pos_enc = np.zeros((1, 64, 64), dtype=np.float32)
    for sq in range(64):
        pos_enc[0, sq, sq] = -1.0
        pos_enc[0, sq, _reachable(sq // 8, sq % 8, False)] = 1.0
    return pos_enc
import tensorflow as tf


def parse_function(planes, probs, winner, q, plies_left):
    """
    Convert unpacked record batches to tensors for tensorflow training
    """
    target_shapes = ((-1, 112, 8, 8), (-1, 1858), (-1, 3), (-1, 3), (-1, 1))
    raw_fields = (planes, probs, winner, q, plies_left)
    planes, probs, winner, q, plies_left = (
        tf.reshape(tf.io.decode_raw(raw, tf.float32), shape)
        for raw, shape in zip(raw_fields, target_shapes)
    )
    return (planes, probs, winner, q, plies_left)
Currently supported training record versions are V3, V4, V5, and V6.

shufflebuffer.ShuffleBuffer is a simple structure holding an array of training
records that are efficiently randomized and replaced as needed. All records in
ShuffleBuffer are adjusted to be the same number of bytes by appending unused
bytes *before* being put in the shuffler. Byte padding is done in
chunkparser.ChunkParser.sample_record().

sample_record() also skips most training records to avoid sampling
over-correlated positions since they typically are from sequential positions
in a game.

Current implementation of "diff focus" also is in sample_record() and works by
probabilistically skipping records according to how accurate the no-search
eval ('orig_q') is compared to eval after search ('best_q') as well as the
recorded policy_kld (a measure of difference between no search policy and the
final policy distribution). It does not use draw values at this point. Putting
diff focus here is efficient because it runs in parallel workers and peeks at
the records without requiring any unpacking.

The constructor for chunkparser.ChunkParser() sets a bunch of class constants
and creates a fixed number of parallel Python multiprocessing.Pipe objects,
which consist of a "reader" and a "writer". The writer(s) get data directly
from training data files and write them into the pipe using the
writer.send_bytes() method. The reader(s) get data out of the pipe using the
reader.recv_bytes() method and feed them to the ShuffleBuffer using its
insert_or_replace() method, which also handles the shuffling itself.

Records come back out of the ShuffleBuffer (already a fixed byte number
regardless of training version) using the multiplexed generators specified in
the ChunkParser.parse() method.
# Version tags and struct layouts for the supported training record formats.
V6_VERSION = struct.pack('i', 6)
V5_VERSION = struct.pack('i', 5)
CLASSICAL_INPUT = struct.pack('i', 1)
V4_VERSION = struct.pack('i', 4)
V3_VERSION = struct.pack('i', 3)

V6_STRUCT_STRING = '4si7432s832sBBBBBBBbfffffffffffffffIHH4H'
V5_STRUCT_STRING = '4si7432s832sBBBBBBBbfffffff'
V4_STRUCT_STRING = '4s7432s832sBBBBBBBbffff'
V3_STRUCT_STRING = '4s7432s832sBBBBBBBb'


def reverse_expand_bits(plane):
    """Expand one packed bit-plane byte to 8 float32 values, LSB first."""
    bits = np.unpackbits(np.array([plane], dtype=np.uint8))
    return bits[::-1].astype(np.float32).tobytes()


class ChunkDataSrc:
    """Interface for a chunk data source: hands out items from the tail of
    a list, one at a time."""

    def __init__(self, items):
        self.items = items

    def next(self):
        """Return the next item, or None once the source is exhausted."""
        return self.items.pop() if self.items else None
def chunk_reader(chunk_filenames, chunk_filename_queue):
    """
    Reads chunk filenames from a list and writes them in shuffled order
    to output_pipes.

    Cycles over the filename list forever, reshuffling on every pass;
    returns only if the list is empty.
    """
    chunks = []
    done = chunk_filenames
    while True:
        if not chunks:
            # Swap the consumed list back in and reshuffle for the next pass.
            chunks, done = done, chunks
            random.shuffle(chunks)
        if not chunks:
            print("chunk_reader didn't find any chunks.")
            return None
        while len(chunks):
            filename = chunks.pop()
            done.append(filename)
            chunk_filename_queue.put(filename)
    # NOTE: unreachable — the while True loop above has no break.
    print("chunk_reader exiting.")
    return None


class ChunkParser:
    """Public facade around ChunkParserInner, which owns the actual parsing
    pipeline and attaches worker-process handles here."""

    def __init__(self,
                 chunks,
                 expected_input_format,
                 shuffle_size=1,
                 sample=1,
                 buffer_size=1,
                 batch_size=256,
                 diff_focus_min=1,
                 diff_focus_slope=0,
                 diff_focus_q_weight=6.0,
                 diff_focus_pol_scale=3.5,
                 workers=None):
        self.inner = ChunkParserInner(self, chunks, expected_input_format,
                                      shuffle_size, sample, buffer_size,
                                      batch_size, diff_focus_min,
                                      diff_focus_slope, diff_focus_q_weight,
                                      diff_focus_pol_scale, workers)

    def shutdown(self):
        """
        Terminates all the workers

        Fix: `processes` and `chunk_process` are attached by
        ChunkParserInner only when workers > 0; previously calling
        shutdown() on a parser constructed with workers=0 raised
        AttributeError. Now it is a safe no-op in that case.
        """
        for i, process in enumerate(getattr(self, "processes", [])):
            process.terminate()
            process.join()
            self.inner.readers[i].close()
            self.inner.writers[i].close()
        chunk_process = getattr(self, "chunk_process", None)
        if chunk_process is not None:
            chunk_process.terminate()
            chunk_process.join()

    def parse(self):
        """Yield shuffled training batches (delegates to the inner parser)."""
        return self.inner.parse()

    def sequential(self):
        """Yield batches in deterministic file order (no worker processes)."""
        return self.inner.sequential()
class ChunkParserInner:
    """Implementation behind ChunkParser; see the module docstring for the
    overall dataflow."""

    def __init__(self, parent, chunks, expected_input_format, shuffle_size,
                 sample, buffer_size, batch_size, diff_focus_min,
                 diff_focus_slope, diff_focus_q_weight, diff_focus_pol_scale,
                 workers):
        """
        Read data and yield batches of raw tensors.

        'parent' the outer chunk parser to store processes.
            Must not be stored by self directly or indirectly.
        'chunks' list of chunk filenames.
        'shuffle_size' is the size of the shuffle buffer.
        'sample' is the rate to down-sample.
        'diff_focus_min', 'diff_focus_slope', 'diff_focus_q_weight' and
            'diff_focus_pol_scale' control diff focus.
        'workers' is the number of child workers to use.

        The data is represented in a number of formats through this
        dataflow pipeline. In order, they are:

        chunk: The name of a file containing chunkdata.

        chunkdata: type Bytes. Multiple records of v6 format where each
            record consists of (state, policy, result, q).

        raw: A byte string holding raw tensors concatenated together.
            This is used to pass data from the workers to the parent.
            Exists because TensorFlow doesn't have a fast way to unpack
            bit vectors. 7950 bytes long.
        """
        self.expected_input_format = expected_input_format

        # Build 2 flat float32 planes with values 0,1
        self.flat_planes = []
        for i in range(2):
            self.flat_planes.append(
                (np.zeros(64, dtype=np.float32) + i).tobytes())

        # set the down-sampling rate
        self.sample = sample
        # set the details for diff focus, defaults accept all positions
        self.diff_focus_min = diff_focus_min
        self.diff_focus_slope = diff_focus_slope
        self.diff_focus_q_weight = diff_focus_q_weight
        self.diff_focus_pol_scale = diff_focus_pol_scale
        # set the mini-batch size
        self.batch_size = batch_size
        # set number of elements in the shuffle buffer.
        self.shuffle_size = shuffle_size
        # Start worker processes, leave 2 for TensorFlow
        if workers is None:
            workers = max(1, mp.cpu_count() - 2)

        if workers > 0:
            print("Using {} worker processes.".format(workers))

            # Start the child workers running
            self.readers = []
            self.writers = []
            # Process handles live on the parent (see docstring: self must
            # not hold them, it gets pickled into the workers).
            parent.processes = []
            self.chunk_filename_queue = mp.Queue(maxsize=4096)
            for _ in range(workers):
                read, write = mp.Pipe(duplex=False)
                p = mp.Process(target=self.task,
                               args=(self.chunk_filename_queue, write))
                p.daemon = True
                parent.processes.append(p)
                p.start()
                self.readers.append(read)
                self.writers.append(write)
            parent.chunk_process = mp.Process(target=chunk_reader,
                                              args=(chunks,
                                                    self.chunk_filename_queue))
            parent.chunk_process.daemon = True
            parent.chunk_process.start()
        else:
            # Synchronous mode: keep the chunk list and parse in-process.
            self.chunks = chunks
            self.init_structs()
    def init_structs(self):
        """
        struct.Struct doesn't pickle, so it needs to be separately
        constructed in workers.
        """
        self.v6_struct = struct.Struct(V6_STRUCT_STRING)
        self.v5_struct = struct.Struct(V5_STRUCT_STRING)
        self.v4_struct = struct.Struct(V4_STRUCT_STRING)
        self.v3_struct = struct.Struct(V3_STRUCT_STRING)

    def convert_v6_to_tuple(self, content):
        """
        Unpack a v6 binary record to 5-tuple (state, policy pi, result, q, m)

        v6 struct format is (8356 bytes total):
                                  size 1st byte index
        uint32_t version;                               0
        uint32_t input_format;                          4
        float probabilities[1858];  7432 bytes          8
        uint64_t planes[104];        832 bytes       7440
        uint8_t castling_us_ooo;                     8272
        uint8_t castling_us_oo;                      8273
        uint8_t castling_them_ooo;                   8274
        uint8_t castling_them_oo;                    8275
        uint8_t side_to_move_or_enpassant;           8276
        uint8_t rule50_count;                        8277
        // Bitfield with the following allocation:
        //  bit 7: side to move (input type 3)
        //  bit 6: position marked for deletion by the rescorer (never set by lc0)
        //  bit 5: game adjudicated (v6)
        //  bit 4: max game length exceeded (v6)
        //  bit 3: best_q is for proven best move (v6)
        //  bit 2: transpose transform (input type 3)
        //  bit 1: mirror transform (input type 3)
        //  bit 0: flip transform (input type 3)
        uint8_t invariance_info;                     8278
        uint8_t dep_result;                          8279
        float root_q;                                8280
        float best_q;                                8284
        float root_d;                                8288
        float best_d;                                8292
        float root_m;      // In plies.              8296
        float best_m;      // In plies.              8300
        float plies_left;                            8304
        float result_q;                              8308
        float result_d;                              8312
        float played_q;                              8316
        float played_d;                              8320
        float played_m;                              8324
        // The folowing may be NaN if not found in cache.
        float orig_q;      // For value repair.      8328
        float orig_d;                                8332
        float orig_m;                                8336
        uint32_t visits;                             8340
        // Indices in the probabilities array.
        uint16_t played_idx;                         8344
        uint16_t best_idx;                           8346
        uint64_t reserved;                           8348
        """
        # unpack the V6 content from raw byte array, arbitrarily chose 4 2-byte
        # values for the 8 "reserved" bytes
        (ver, input_format, probs, planes, us_ooo, us_oo, them_ooo, them_oo,
         stm, rule50_count, invariance_info, dep_result, root_q, best_q,
         root_d, best_d, root_m, best_m, plies_left, result_q, result_d,
         played_q, played_d, played_m, orig_q, orig_d, orig_m, visits,
         played_idx, best_idx, reserved1, reserved2, reserved3,
         reserved4) = self.v6_struct.unpack(content)
        """
        v5 struct format was (8308 bytes total)
            int32 version (4 bytes)
            int32 input_format (4 bytes)
            1858 float32 probabilities (7432 bytes)
            104 (13*8) packed bit planes of 8 bytes each (832 bytes)
            uint8 castling us_ooo (1 byte)
            uint8 castling us_oo (1 byte)
            uint8 castling them_ooo (1 byte)
            uint8 castling them_oo (1 byte)
            uint8 side_to_move (1 byte)
            uint8 rule50_count (1 byte)
            uint8 dep_ply_count (1 byte) (unused)
            int8 result (1 byte)
            float32 root_q (4 bytes)
            float32 best_q (4 bytes)
            float32 root_d (4 bytes)
            float32 best_d (4 bytes)
            float32 root_m (4 bytes)
            float32 best_m (4 bytes)
            float32 plies_left (4 bytes)
        """
        # v3/4 data sometimes has a useful value in dep_ply_count (now
        # invariance_info), so copy that over if the new ply_count is not
        # populated.
        if plies_left == 0:
            plies_left = invariance_info
        plies_left = struct.pack('f', plies_left)

        assert input_format == self.expected_input_format

        # Unpack bit planes and cast to 32 bit float
        planes = np.unpackbits(np.frombuffer(planes, dtype=np.uint8)).astype(
            np.float32)
        # rule50 normalization divisor depends on the input format version.
        rule50_divisor = 99.0
        if input_format > 3:
            rule50_divisor = 100.0
        rule50_plane = struct.pack('f', rule50_count / rule50_divisor) * 64

        if input_format == 1:
            middle_planes = self.flat_planes[us_ooo] + \
                            self.flat_planes[us_oo] + \
                            self.flat_planes[them_ooo] + \
                            self.flat_planes[them_oo] + \
                            self.flat_planes[stm]
        elif input_format == 2:
            # Each inner array has to be reversed as these fields are in
            # opposite endian to the planes data.
            them_ooo_bytes = reverse_expand_bits(them_ooo)
            us_ooo_bytes = reverse_expand_bits(us_ooo)
            them_oo_bytes = reverse_expand_bits(them_oo)
            us_oo_bytes = reverse_expand_bits(us_oo)
            middle_planes = us_ooo_bytes + (6*8*4) * b'\x00' + them_ooo_bytes + \
                            us_oo_bytes + (6*8*4) * b'\x00' + them_oo_bytes + \
                            self.flat_planes[0] + \
                            self.flat_planes[0] + \
                            self.flat_planes[stm]
        elif input_format == 3 or input_format == 4 or input_format == 132 or \
                input_format == 5 or input_format == 133:
            # Each inner array has to be reversed as these fields are in
            # opposite endian to the planes data.
            them_ooo_bytes = reverse_expand_bits(them_ooo)
            us_ooo_bytes = reverse_expand_bits(us_ooo)
            them_oo_bytes = reverse_expand_bits(them_oo)
            us_oo_bytes = reverse_expand_bits(us_oo)
            # In these formats the "stm" byte instead encodes en-passant.
            enpassant_bytes = reverse_expand_bits(stm)
            middle_planes = us_ooo_bytes + (6*8*4) * b'\x00' + them_ooo_bytes + \
                            us_oo_bytes + (6*8*4) * b'\x00' + them_oo_bytes + \
                            self.flat_planes[0] + \
                            self.flat_planes[0] + \
                            (7*8*4) * b'\x00' + enpassant_bytes

        # Concatenate all byteplanes. Make the last plane all 1's so the NN can
        # detect edges of the board more easily
        aux_plus_6_plane = self.flat_planes[0]
        if (input_format == 132
                or input_format == 133) and invariance_info >= 128:
            aux_plus_6_plane = self.flat_planes[1]
        planes = planes.tobytes() + \
                 middle_planes + \
                 rule50_plane + \
                 aux_plus_6_plane + \
                 self.flat_planes[1]

        # 13 planes per history step (x8) plus 8 aux planes, 8x8 floats each.
        assert len(planes) == ((8 * 13 * 1 + 8 * 1 * 1) * 8 * 8 * 4)

        if ver == V6_VERSION:
            winner = struct.pack('fff',
                                 0.5 * (1.0 - result_d + result_q),
                                 result_d,
                                 0.5 * (1.0 - result_d - result_q))
        else:
            dep_result = float(dep_result)
            assert dep_result == 1.0 or dep_result == -1.0 or dep_result == 0.0
            winner = struct.pack('fff', dep_result == 1.0, dep_result == 0.0,
                                 dep_result == -1.0)

        # Convert (q, d) into a (win, draw, loss) triple.
        best_q_w = 0.5 * (1.0 - best_d + best_q)
        best_q_l = 0.5 * (1.0 - best_d - best_q)
        assert -1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0
        best_q = struct.pack('fff', best_q_w, best_d, best_q_l)

        return (planes, probs, winner, best_q, plies_left)
record = chunkdata[i:i + record_size] # for earlier versions, append fake bytes to record to maintain size if version == V3_VERSION: # add 16 bytes of fake root_q, best_q, root_d, best_d to match V4 format record += 16 * b'\x00' if version == V3_VERSION or version == V4_VERSION: # add 12 bytes of fake root_m, best_m, plies_left to match V5 format record += 12 * b'\x00' # insert 4 bytes of classical input format tag to match v5 format record = record[:4] + CLASSICAL_INPUT + record[4:] if version == V3_VERSION or version == V4_VERSION or version == V5_VERSION: # add 48 byes of fake result_q, result_d etc record += 48 * b'\x00' if version == V6_VERSION: # diff focus code, peek at best_q, orig_q and pol_kld from record (unpacks as tuple with one item) best_q = struct.unpack('f', record[8284:8288])[0] orig_q = struct.unpack('f', record[8328:8332])[0] pol_kld = struct.unpack('f', record[8348:8352])[0] # if orig_q is NaN or pol_kld is 0, accept, else accept based on diff focus if not np.isnan(orig_q) and pol_kld > 0: diff_q = abs(best_q - orig_q) q_weight = self.diff_focus_q_weight pol_scale = self.diff_focus_pol_scale total = (q_weight * diff_q + pol_kld) / (q_weight + pol_scale) thresh_p = self.diff_focus_min + self.diff_focus_slope * total if thresh_p < 1.0 and random.random() > thresh_p: continue yield record def single_file_gen(self, filename): try: with gzip.open(filename, 'rb') as chunk_file: version = chunk_file.read(4) chunk_file.seek(0) if version == V6_VERSION: record_size = self.v6_struct.size elif version == V5_VERSION: record_size = self.v5_struct.size elif version == V4_VERSION: record_size = self.v4_struct.size elif version == V3_VERSION: record_size = self.v3_struct.size else: print('Unknown version {} in file {}'.format( version, filename)) return while True: chunkdata = chunk_file.read(256 * record_size) if len(chunkdata) == 0: break for item in self.sample_record(chunkdata): yield item except: print("failed to parse {}".format(filename)) def 
def sequential_gen(self):
    """Yield sampled records from all chunk files, in order."""
    for filename in self.chunks:
        for item in self.single_file_gen(filename):
            yield item


def sequential(self):
    """Yield full batches read sequentially (no shuffling, no workers)."""
    gen = self.sequential_gen()  # read from all files in order in this process.
    gen = self.tuple_gen(gen)  # convert v6->tuple
    gen = self.batch_gen(gen, allow_partial=False)  # assemble into batches
    for b in gen:
        yield b


def task(self, chunk_filename_queue, writer):
    """
    Run in fork'ed process, read data from chunkdatasrc, parsing,
    shuffling and sending v6 data through pipe back to main process.
    """
    self.init_structs()
    while True:
        filename = chunk_filename_queue.get()
        for item in self.single_file_gen(filename):
            writer.send_bytes(item)


def v6_gen(self):
    """
    Read v6 records from child workers, shuffle, and yield records.
    """
    sbuff = sb.ShuffleBuffer(self.v6_struct.size, self.shuffle_size)
    while len(self.readers):
        # NOTE(fix): iterate over a snapshot of the reader list. The
        # original iterated self.readers directly while also calling
        # self.readers.remove(r) on EOF, which skips the element that
        # follows the removed one.
        for r in list(self.readers):
            try:
                s = r.recv_bytes()
                s = sbuff.insert_or_replace(s)
                if s is None:
                    continue  # shuffle buffer not yet full
                yield s
            except EOFError:
                print("Reader EOF")
                self.readers.remove(r)
    # drain the shuffle buffer.
    while True:
        s = sbuff.extract()
        if s is None:
            return
        yield s


def tuple_gen(self, gen):
    """
    Take a generator producing v6 records and convert them to tuples,
    applying a random symmetry on the way.
    """
    for r in gen:
        yield self.convert_v6_to_tuple(r)


def batch_gen(self, gen, allow_partial=True):
    """
    Pack multiple records into a single batch.

    Each yielded batch is a 5-tuple of concatenated byte strings, one
    per field of the record tuples. A trailing partial batch is dropped
    unless allow_partial is True.
    """
    # Get N records. We flatten the returned generator to
    # a list because we need to reuse it.
    while True:
        s = list(itertools.islice(gen, self.batch_size))
        if not s or (not allow_partial and len(s) != self.batch_size):
            return
        yield (b''.join([x[0] for x in s]),
               b''.join([x[1] for x in s]),
               b''.join([x[2] for x in s]),
               b''.join([x[3] for x in s]),
               b''.join([x[4] for x in s]))
while True: s = list(itertools.islice(gen, self.batch_size)) if not len(s) or (not allow_partial and len(s) != self.batch_size): return yield (b''.join([x[0] for x in s]), b''.join([x[1] for x in s]), b''.join([x[2] for x in s]), b''.join([x[3] for x in s]), b''.join([x[4] for x in s])) def parse(self): """ Read data from child workers and yield batches of unpacked records """ gen = self.v6_gen() # read from workers gen = self.tuple_gen(gen) # convert v6->tuple gen = self.batch_gen(gen) # assemble into batches for b in gen: yield b # Tests to check that records parse correctly class ChunkParserTest(unittest.TestCase): def setUp(self): self.v4_struct = struct.Struct(V4_STRUCT_STRING) def generate_fake_pos(self): """ Generate a random game position. Result is ([[64] * 104], [1]*5, [1858], [1], [1]) """ # 0. 104 binary planes of length 64 planes = [ np.random.randint(2, size=64).tolist() for plane in range(104) ] # 1. generate the other integer data integer = np.zeros(7, dtype=np.int32) for i in range(5): integer[i] = np.random.randint(2) integer[5] = np.random.randint(100) # 2. 1858 probs probs = np.random.randint(9, size=1858, dtype=np.int32) # 3. And a winner: 1, 0, -1 winner = np.random.randint(3) - 1 # 4. evaluation after search best_q = np.random.uniform(-1, 1) best_d = np.random.uniform(0, 1 - np.abs(best_q)) return (planes, integer, probs, winner, best_q, best_d) def v4_record(self, planes, i, probs, winner, best_q, best_d): pl = [] for plane in planes: pl.append(np.packbits(plane)) pl = np.array(pl).flatten().tobytes() pi = probs.tobytes() root_q, root_d = 0.0, 0.0 return self.v4_struct.pack(V4_VERSION, pi, pl, i[0], i[1], i[2], i[3], i[4], i[5], i[6], winner, root_q, best_q, root_d, best_d) def test_structsize(self): """ Test struct size """ self.assertEqual(self.v4_struct.size, 8292) def test_parsing(self): """ Test game position decoding pipeline. 
""" truth = self.generate_fake_pos() batch_size = 4 records = [] for i in range(batch_size): record = b'' for j in range(2): record += self.v4_record(*truth) records.append(record) parser = ChunkParser(ChunkDataSrc(records), shuffle_size=1, workers=1, batch_size=batch_size) batchgen = parser.parse() data = next(batchgen) batch = (np.reshape(np.frombuffer(data[0], dtype=np.float32), (batch_size, 112, 64)), np.reshape(np.frombuffer(data[1], dtype=np.int32), (batch_size, 1858)), np.reshape(np.frombuffer(data[2], dtype=np.float32), (batch_size, 3)), np.reshape(np.frombuffer(data[3], dtype=np.float32), (batch_size, 3))) fltplanes = truth[1].astype(np.float32) fltplanes[5] /= 99 for i in range(batch_size): data = (batch[0][i][:104], np.array([batch[0][i][j][0] for j in range(104, 111)]), batch[1][i], batch[2][i], batch[3][i]) self.assertTrue((data[0] == truth[0]).all()) self.assertTrue((data[1] == fltplanes).all()) self.assertTrue((data[2] == truth[2]).all()) scalar_win = data[3][0] - data[3][-1] self.assertTrue(np.abs(scalar_win - truth[3]) < 1e-6) scalar_q = data[4][0] - data[4][-1] self.assertTrue(np.abs(scalar_q - truth[4]) < 1e-6) parser.shutdown() if __name__ == '__main__': unittest.main() ================================================ FILE: tf/configs/example.yaml ================================================ %YAML 1.2 --- name: 'kb1-64x6' # ideally no spaces gpu: 0 # gpu id to process on dataset: num_chunks: 100000 # newest nof chunks to parse train_ratio: 0.90 # trainingset ratio # For separated test and train data. input_train: '/path/to/chunks/*/draw/' # supports glob input_test: '/path/to/chunks/*/draw/' # supports glob # For a one-shot run with all data in one directory. # input: '/path/to/chunks/*/draw/' training: batch_size: 2048 # training batch test_steps: 2000 # eval test set values after this many steps train_avg_report_steps: 200 # training reports its average values after this many steps. 
total_steps: 140000 # terminate after these steps warmup_steps: 250 # if global step is less than this, scale the current LR by ratio of global step to this value # checkpoint_steps: 10000 # optional frequency for checkpointing before finish shuffle_size: 524288 # size of the shuffle buffer lr_values: # list of learning rates - 0.02 - 0.002 - 0.0005 lr_boundaries: # list of boundaries - 100000 - 130000 policy_loss_weight: 1.0 # weight of policy loss value_loss_weight: 1.0 # weight of value loss moves_left_loss_weight: 1.0 # weight of moves-left loss path: '/path/to/store/networks' # network storage dir model: filters: 64 residual_blocks: 6 se_ratio: 2 # Squeeze Excite structural network architecture. policy: 'attention' # attention policy fields: pol_embedding_size: 64 # embedding vector size pol_encoder_layers: 1 # number of intermediate attention layers in the policy head pol_encoder_heads: 4 # number of attention heads in encoder layers pol_encoder_d_model: 64 # size of the Q, K, & V vectors in encoder layers -- divisible by encoder_heads pol_encoder_dff: 128 # size of the largest dense layer in encoder block feed-forward network policy_d_model: 64 # size of the query and key vectors in final attention layer value: 'wdl' moves_left: 'v1' ... ================================================ FILE: tf/decode_training.py ================================================ #!/usr/bin/env python3 # # This file is part of Leela Chess. # Copyright (C) 2018 Folkert Huizinga # Copyright (C) 2017-2018 Gian-Carlo Pascutto # # Leela Chess is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # Leela Chess is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Leela Chess. If not, see . import array import binascii import chunkparser import glob import gzip import itertools import math import numpy as np import random import re import os import shutil import struct import sys import threading import time import unittest import argparse from collections import defaultdict # VERSION of the training data file format # 1 - Text, oldflip # 2 - Binary, oldflip # 3 - Binary, newflip # b'\1\0\0\0' - Invalid, see issue #119 # # Note: VERSION1 does not include a version in the header, it starts with # text hex characters. This means any VERSION that is also a valid ASCII # hex string could potentially be a training file. Most of the time it # will be "00ff", but maybe games could get split and then there are more # "00??" possibilities. # # Also note "0002" is actually b'\0x30\0x30\0x30\0x32' (or maybe reversed?) # so it doesn't collide with VERSION2. 
# VERSION3 = chunkparser.V3_VERSION VERSION4 = chunkparser.V4_VERSION V3_BYTES = 8276 V4_BYTES = 8292 # Us -- uppercase # Them -- lowercase PIECES = "PNBRQKpnbrqk" MOVES = [ "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2", "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6", "a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1", "b1c1", "b1d1", "b1e1", "b1f1", "b1g1", "b1h1", "b1a2", "b1b2", "b1c2", "b1d2", "b1a3", "b1b3", "b1c3", "b1d3", "b1b4", "b1e4", "b1b5", "b1f5", "b1b6", "b1g6", "b1b7", "b1h7", "b1b8", "c1a1", "c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1", "c1a2", "c1b2", "c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3", "c1e3", "c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8", "d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2", "d1c2", "d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3", "d1f3", "d1a4", "d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7", "d1d8", "e1a1", "e1b1", "e1c1", "e1d1", "e1f1", "e1g1", "e1h1", "e1c2", "e1d2", "e1e2", "e1f2", "e1g2", "e1c3", "e1d3", "e1e3", "e1f3", "e1g3", "e1b4", "e1e4", "e1h4", "e1a5", "e1e5", "e1e6", "e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1", "f1e1", "f1g1", "f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3", "f1e3", "f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6", "f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1", "g1f1", "g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3", "g1g3", "g1h3", "g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6", "g1a7", "g1g7", "g1g8", "h1a1", "h1b1", "h1c1", "h1d1", "h1e1", "h1f1", "h1g1", "h1f2", "h1g2", "h1h2", "h1f3", "h1g3", "h1h3", "h1e4", "h1h4", "h1d5", "h1h5", "h1c6", "h1h6", "h1b7", "h1h7", "h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2", "a2c2", "a2d2", "a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3", "a2a4", "a2b4", "a2c4", "a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7", "a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2", "b2d2", "b2e2", "b2f2", "b2g2", 
"b2h2", "b2a3", "b2b3", "b2c3", "b2d3", "b2a4", "b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6", "b2f6", "b2b7", "b2g7", "b2b8", "b2h8", "c2a1", "c2b1", "c2c1", "c2d1", "c2e1", "c2a2", "c2b2", "c2d2", "c2e2", "c2f2", "c2g2", "c2h2", "c2a3", "c2b3", "c2c3", "c2d3", "c2e3", "c2a4", "c2b4", "c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6", "c2g6", "c2c7", "c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1", "d2a2", "d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3", "d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4", "d2a5", "d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1", "e2d1", "e2e1", "e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2", "e2f2", "e2g2", "e2h2", "e2c3", "e2d3", "e2e3", "e2f3", "e2g3", "e2c4", "e2d4", "e2e4", "e2f4", "e2g4", "e2b5", "e2e5", "e2h5", "e2a6", "e2e6", "e2e7", "e2e8", "f2d1", "f2e1", "f2f1", "f2g1", "f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2", "f2g2", "f2h2", "f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4", "f2f4", "f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7", "f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2", "g2d2", "g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3", "g2e4", "g2f4", "g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6", "g2b7", "g2g7", "g2a8", "g2g8", "h2f1", "h2g1", "h2h1", "h2a2", "h2b2", "h2c2", "h2d2", "h2e2", "h2f2", "h2g2", "h2f3", "h2g3", "h2h3", "h2f4", "h2g4", "h2h4", "h2e5", "h2h5", "h2d6", "h2h6", "h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1", "a3c1", "a3a2", "a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3", "a3g3", "a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6", "a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1", "b3d1", "b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3", "b3e3", "b3f3", "b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4", "b3a5", "b3b5", "b3c5", "b3d5", "b3b6", "b3e6", "b3b7", "b3f7", "b3b8", "b3g8", "c3a1", "c3b1", "c3c1", "c3d1", "c3e1", "c3a2", "c3b2", "c3c2", "c3d2", "c3e2", "c3a3", "c3b3", 
"c3d3", "c3e3", "c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4", "c3d4", "c3e4", "c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6", "c3c7", "c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1", "d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3", "d3e3", "d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4", "d3f4", "d3b5", "d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6", "d3g6", "d3d7", "d3h7", "d3d8", "e3c1", "e3d1", "e3e1", "e3f1", "e3g1", "e3c2", "e3d2", "e3e2", "e3f2", "e3g2", "e3a3", "e3b3", "e3c3", "e3d3", "e3f3", "e3g3", "e3h3", "e3c4", "e3d4", "e3e4", "e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5", "e3g5", "e3b6", "e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1", "f3f1", "f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3", "f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4", "f3f4", "f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5", "f3c6", "f3f6", "f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1", "g3g1", "g3h1", "g3e2", "g3f2", "g3g2", "g3h2", "g3a3", "g3b3", "g3c3", "g3d3", "g3e3", "g3f3", "g3h3", "g3e4", "g3f4", "g3g4", "g3h4", "g3e5", "g3f5", "g3g5", "g3h5", "g3d6", "g3g6", "g3c7", "g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1", "h3f2", "h3g2", "h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3", "h3g3", "h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6", "h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2", "a4c2", "a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4", "a4f4", "a4g4", "a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6", "a4c6", "a4a7", "a4d7", "a4a8", "a4e8", "b4b1", "b4e1", "b4a2", "b4b2", "b4c2", "b4d2", "b4a3", "b4b3", "b4c3", "b4d3", "b4a4", "b4c4", "b4d4", "b4e4", "b4f4", "b4g4", "b4h4", "b4a5", "b4b5", "b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6", "b4b7", "b4e7", "b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", "c4c2", "c4d2", "c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4", "c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5", 
"c4d5", "c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7", "c4f7", "c4c8", "c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2", "d4d2", "d4e2", "d4f2", "d4b3", "d4c3", "d4d3", "d4e3", "d4f3", "d4a4", "d4b4", "d4c4", "d4e4", "d4f4", "d4g4", "d4h4", "d4b5", "d4c5", "d4d5", "d4e5", "d4f5", "d4b6", "d4c6", "d4d6", "d4e6", "d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8", "e4b1", "e4e1", "e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3", "e4d3", "e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4", "e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6", "e4d6", "e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8", "e4e8", "f4c1", "f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2", "f4d3", "f4e3", "f4f3", "f4g3", "f4h3", "f4a4", "f4b4", "f4c4", "f4d4", "f4e4", "f4g4", "f4h4", "f4d5", "f4e5", "f4f5", "f4g5", "f4h5", "f4d6", "f4e6", "f4f6", "f4g6", "f4h6", "f4c7", "f4f7", "f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2", "g4g2", "g4h2", "g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4", "g4d4", "g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6", "g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1", "h4h1", "h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4", "h4b4", "h4c4", "h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5", "h4h5", "h4f6", "h4g6", "h4h6", "h4e7", "h4h7", "h4d8", "h4h8", "a5a1", "a5e1", "a5a2", "a5d2", "a5a3", "a5b3", "a5c3", "a5a4", "a5b4", "a5c4", "a5b5", "a5c5", "a5d5", "a5e5", "a5f5", "a5g5", "a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7", "a5c7", "a5a8", "a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3", "b5c3", "b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5", "b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6", "b5a7", "b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1", "c5c2", "c5f2", "c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4", "c5b4", "c5c4", "c5d4", "c5e4", "c5a5", "c5b5", "c5d5", "c5e5", "c5f5", "c5g5", "c5h5", "c5a6", "c5b6", "c5c6", "c5d6", "c5e6", "c5a7", "c5b7", 
"c5c7", "c5d7", "c5e7", "c5c8", "c5f8", "d5d1", "d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3", "d5d3", "d5e3", "d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5", "d5b5", "d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6", "d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8", "d5d8", "d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3", "e5d3", "e5e3", "e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4", "e5g4", "e5a5", "e5b5", "e5c5", "e5d5", "e5f5", "e5g5", "e5h5", "e5c6", "e5d6", "e5e6", "e5f6", "e5g6", "e5c7", "e5d7", "e5e7", "e5f7", "e5g7", "e5b8", "e5e8", "e5h8", "f5b1", "f5f1", "f5c2", "f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3", "f5d4", "f5e4", "f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5", "f5e5", "f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7", "f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1", "g5d2", "g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4", "g5g4", "g5h4", "g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5", "g5h5", "g5e6", "g5f6", "g5g6", "g5h6", "g5e7", "g5f7", "g5g7", "g5h7", "g5d8", "g5g8", "h5d1", "h5h1", "h5e2", "h5h2", "h5f3", "h5g3", "h5h3", "h5f4", "h5g4", "h5h4", "h5a5", "h5b5", "h5c5", "h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6", "h5h6", "h5f7", "h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2", "a6e2", "a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5", "a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7", "a6b7", "a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2", "b6f2", "b6b3", "b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5", "b6b5", "b6c5", "b6d5", "b6a6", "b6c6", "b6d6", "b6e6", "b6f6", "b6g6", "b6h6", "b6a7", "b6b7", "b6c7", "b6d7", "b6a8", "b6b8", "b6c8", "b6d8", "c6c1", "c6h1", "c6c2", "c6g2", "c6c3", "c6f3", "c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5", "c6b5", "c6c5", "c6d5", "c6e5", "c6a6", "c6b6", "c6d6", "c6e6", "c6f6", "c6g6", "c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8", "c6c8", "c6d8", "c6e8", "d6d1", 
"d6d2", "d6h2", "d6a3", "d6d3", "d6g3", "d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5", "d6d5", "d6e5", "d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6", "d6g6", "d6h6", "d6b7", "d6c7", "d6d7", "d6e7", "d6f7", "d6b8", "d6c8", "d6d8", "d6e8", "d6f8", "e6e1", "e6a2", "e6e2", "e6b3", "e6e3", "e6h3", "e6c4", "e6d4", "e6e4", "e6f4", "e6g4", "e6c5", "e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6", "e6c6", "e6d6", "e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7", "e6g7", "e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2", "f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4", "f6d5", "f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6", "f6d6", "f6e6", "f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7", "f6h7", "f6d8", "f6e8", "f6f8", "f6g8", "f6h8", "g6b1", "g6g1", "g6c2", "g6g2", "g6d3", "g6g3", "g6e4", "g6f4", "g6g4", "g6h4", "g6e5", "g6f5", "g6g5", "g6h5", "g6a6", "g6b6", "g6c6", "g6d6", "g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7", "g6h7", "g6e8", "g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2", "h6e3", "h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6", "h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7", "h6h7", "h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2", "a7a3", "a7e3", "a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6", "a7b6", "a7c6", "a7b7", "a7c7", "a7d7", "a7e7", "a7f7", "a7g7", "a7h7", "a7a8", "a7b8", "a7c8", "b7b1", "b7h1", "b7b2", "b7g2", "b7b3", "b7f3", "b7b4", "b7e4", "b7a5", "b7b5", "b7c5", "b7d5", "b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7", "b7d7", "b7e7", "b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8", "c7c1", "c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5", "c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6", "c7a7", "c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8", "c7b8", "c7c8", "c7d8", "c7e8", "d7d1", "d7d2", "d7d3", "d7h3", "d7a4", "d7d4", "d7g4", "d7b5", "d7c5", "d7d5", "d7e5", "d7f5", "d7b6", "d7c6", "d7d6", "d7e6", "d7f6", "d7a7", 
"d7b7", "d7c7", "d7e7", "d7f7", "d7g7", "d7h7", "d7b8", "d7c8", "d7d8", "d7e8", "d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4", "e7e4", "e7h4", "e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6", "e7e6", "e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7", "e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2", "f7f2", "f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5", "f7g5", "f7h5", "f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7", "f7b7", "f7c7", "f7d7", "f7e7", "f7g7", "f7h7", "f7d8", "f7e8", "f7f8", "f7g8", "f7h8", "g7a1", "g7g1", "g7b2", "g7g2", "g7c3", "g7g3", "g7d4", "g7g4", "g7e5", "g7f5", "g7g5", "g7h5", "g7e6", "g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7", "g7d7", "g7e7", "g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1", "h7h1", "h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5", "h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7", "h7e7", "h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1", "a8a2", "a8g2", "a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5", "a8a6", "a8b6", "a8c6", "a8a7", "a8b7", "a8c7", "a8b8", "a8c8", "a8d8", "a8e8", "a8f8", "a8g8", "a8h8", "b8b1", "b8b2", "b8h2", "b8b3", "b8g3", "b8b4", "b8f4", "b8b5", "b8e5", "b8a6", "b8b6", "b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7", "b8a8", "b8c8", "b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2", "c8c3", "c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6", "c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8", "c8b8", "c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2", "d8d3", "d8d4", "d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6", "d8d6", "d8e6", "d8f6", "d8b7", "d8c7", "d8d7", "d8e7", "d8f7", "d8a8", "d8b8", "d8c8", "d8e8", "d8f8", "d8g8", "d8h8", "e8e1", "e8e2", "e8e3", "e8a4", "e8e4", "e8b5", "e8e5", "e8h5", "e8c6", "e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7", "e8e7", "e8f7", "e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8", "e8h8", "f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5", 
"f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7", "f8g7", "f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8", "f8h8", "g8g1", "g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4", "g8d5", "g8g5", "g8e6", "g8f6", "g8g6", "g8h6", "g8e7", "g8f7", "g8g7", "g8h7", "g8a8", "g8b8", "g8c8", "g8d8", "g8e8", "g8f8", "g8h8", "h8a1", "h8h1", "h8b2", "h8h2", "h8c3", "h8h3", "h8d4", "h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6", "h8f7", "h8g7", "h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8", "h8g8", "a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q", "b7a8r", "b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b", "c7b8q", "c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r", "c7d8b", "d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q", "d7e8r", "d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b", "e7f8q", "e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r", "f7f8b", "f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q", "g7g8r", "g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b", "h7h8q", "h7h8r", "h7h8b" ] class Board: def __init__(self): self.clear_board() def clear_board(self): self.board = [] for rank in range(8): self.board.append(list("." 
* 8)) self.reps = 0 def describe(self): s = [] for rank in range(8): s.append("".join(self.board[rank])) s.append("reps {} ".format(self.reps)) return s class TrainingStep: def __init__(self, version): self.version = version # Construct a fake parser just to get access to it's variables self.parser = chunkparser.ChunkParser(chunkparser.ChunkDataSrc([]), workers=1) self.NUM_HIST = 8 self.NUM_PIECE_TYPES = 6 self.V3_NUM_PLANES = self.NUM_PIECE_TYPES * 2 + 1 # = 13 (6*2 us/them pieces, rep1 (no rep2)) self.NUM_PLANES = self.V3_NUM_PLANES self.NUM_REALS = 7 # 4 castling, 1 color, 1 50rule, 1 movecount self.NUM_OUTPUTS = 2 # policy, value self.NUM_PLANES_BYTES = self.NUM_PLANES * 4 self.NUM_PLANES_BYTES = self.NUM_PLANES * 4 self.NUM_PLANES_BYTES = self.NUM_PLANES * 4 self.V3_NUM_POLICY_MOVES = 1858 # (7432 bytes) self.NUM_POLICY_MOVES = self.V3_NUM_POLICY_MOVES self.init_structs() self.init_move_map() self.history = [] self.probs = [] for history in range(self.NUM_HIST): self.history.append(Board()) self.us_ooo = 0 self.us_oo = 0 self.them_ooo = 0 self.them_oo = 0 self.us_black = 0 self.rule50_count = 0 self.winner = None self.q = None def init_structs(self): self.v4_struct = self.parser.v4_struct self.this_struct = self.v4_struct def init_move_map(self): self.new_white_move_map = defaultdict(lambda: -1) self.new_black_move_map = defaultdict(lambda: -1) self.old_rev_move_map = {} self.new_rev_white_move_map = {} self.new_rev_black_move_map = {} for idx, m in enumerate(MOVES): self.new_white_move_map[m] = idx self.new_rev_white_move_map[idx] = m m_black = m.translate(str.maketrans("12345678", "87654321")) self.new_black_move_map[m_black] = idx self.new_rev_black_move_map[idx] = m_black def clear_hist(self): for hist in range(self.NUM_HIST): self.history.clear_board() def update_board(self, hist, piece, bit_board): """ Update the ASCII board representation """ for r in range(8): for f in range(8): # Note: Using 8-1-f because both the text and binary have the # column 
# Methods of TrainingStep (class header is outside this span in the
# original file).
def describe(self):
    """Return a multi-line human-readable dump of this training step."""
    s = ""
    if self.us_black:
        s += "us = Black"
    else:
        s += "us = White"
    if self.winner == 1:
        s += " won\n"
    elif self.winner == -1:
        s += " lost\n"
    elif self.winner == 0:
        s += " draw\n"
    else:
        raise Exception("Invalid winner: {}".format(self.winner))
    s += "Root Q = {} (diff to result: {}) \n".format(
        self.root_q, abs(self.winner - self.root_q))
    s += "Best Q = {} (diff to result: {}) \n".format(
        self.best_q, abs(self.winner - self.best_q))
    if self.us_black:
        s += "(Note the black pieces are CAPS, black moves up, but A1 is in lower left)\n"
    s += "rule50_count {} b_ooo b_oo, w_ooo, w_oo {} {} {} {}\n".format(
        self.rule50_count, self.us_ooo, self.us_oo, self.them_ooo,
        self.them_oo)
    s += " abcdefgh\n"
    rank_strings = [[]]
    for rank in reversed(range(8)):
        rank_strings[0].append("{}".format(rank + 1))
    rank_strings[0].append(" ")
    for hist in range(self.NUM_HIST):
        rank_strings.append(self.history[hist].describe())
    for hist in range(self.NUM_HIST + 1):
        for rank in range(8 + 1):
            #if hist == 8 and rank == 0: continue
            s += rank_strings[rank][hist] + " "
        s += "\n"
    # NOTE(fix): the accumulator was named `sum`, shadowing the builtin.
    prob_sum = 0.0
    top_moves = {}
    for idx, prob in enumerate(self.probs):
        # Include all moves with at least 1 visit.
        condition = prob > 0 if self.version == 3 else prob >= 0
        if condition:
            top_moves[idx] = prob
            prob_sum += prob
    for idx, prob in sorted(top_moves.items(), key=lambda x: -x[1]):
        s += "{} {:4.1f}%\n".format(self.new_rev_white_move_map[idx],
                                    prob * 100)
    #print("debug prob sum", prob_sum, "cnt", len(self.probs))
    return s


def update_reals(self, text_item):
    """Parse castling/color/rule50 scalars from a text (v1) record."""
    self.us_ooo = int(text_item[self.NUM_HIST * self.NUM_PLANES + 0])
    self.us_oo = int(text_item[self.NUM_HIST * self.NUM_PLANES + 1])
    self.them_ooo = int(text_item[self.NUM_HIST * self.NUM_PLANES + 2])
    self.them_oo = int(text_item[self.NUM_HIST * self.NUM_PLANES + 3])
    self.us_black = int(text_item[self.NUM_HIST * self.NUM_PLANES + 4])
    self.rule50_count = min(
        int(text_item[self.NUM_HIST * self.NUM_PLANES + 5]), 255)
    # should be around 99-102ish
    assert self.rule50_count < 105


def flip_single_v1_plane(self, plane):
    """Vertically flip one hex-encoded plane string.

    Splits the hexstring into bytes (2 ascii chars), reverses them, and
    rejoins, which flips the plane vertically.
    """
    return "".join(
        [plane[x:x + 2] for x in reversed(range(0, len(plane), 2))])


def display_v4(self, ply, content):
    """Unpack one V4 record, update internal state, and print it."""
    (ver, probs, planes, us_ooo, us_oo, them_ooo, them_oo, us_black,
     rule50_count, move_count, winner, root_q, best_q, root_d,
     best_d) = self.this_struct.unpack(content)
    assert self.version == int.from_bytes(ver, byteorder="little")
    # Enforce move_count to 0
    move_count = 0
    # Unpack planes.
    for hist in range(self.NUM_HIST):
        for idx, piece in enumerate(PIECES):
            start = hist * self.NUM_PLANES * 8 + idx * 8
            end = start + 8
            self.update_board(
                hist, piece,
                int.from_bytes(planes[start:end], byteorder="big"))
        # Plane 12 is the repetition plane: either all zeros or all ones.
        rep_start = hist * self.NUM_PLANES * 8 + 12 * 8
        rep_plane = planes[rep_start:rep_start + 8]
        if rep_plane != struct.pack('II', 0, 0):
            self.history[hist].reps = 1
            assert rep_plane == struct.pack('II', 0xffffffff, 0xffffffff)
    self.us_ooo = us_ooo
    self.us_oo = us_oo
    self.them_ooo = them_ooo
    self.them_oo = them_oo
    self.us_black = us_black
    self.rule50_count = rule50_count
    self.winner = winner
    self.root_q = root_q
    self.best_q = best_q
    for idx in range(0, len(probs), 4):
        self.probs.append(struct.unpack("f", probs[idx:idx + 4])[0])
    print("ply {} move {} (Not actually part of training data)".format(
        ply + 1, (ply + 2) // 2))
    print(self.describe())
def main(args):
    """Decode each gzipped training chunk in args.files and print it."""
    for filename in args.files:
        #print("Parsing {}".format(filename))
        with gzip.open(filename, 'rb') as f:
            chunkdata = f.read()
        version = chunkdata[0:4]
        if version in {VERSION4, VERSION3}:
            # NOTE(fix): reuse `version` instead of re-slicing
            # chunkdata[0:4] on every record.
            record_size = V3_BYTES if version == VERSION3 else V4_BYTES
            for i in range(0, len(chunkdata), record_size):
                ts = TrainingStep(4 if version == VERSION4 else 3)
                record = chunkdata[i:i + record_size]
                if version == VERSION3:
                    # Pad with fake root_q/best_q/root_d/best_d so the
                    # record unpacks with the V4 struct.
                    record += 16 * b'\x00'
                ts.display_v4(i // record_size, record)
        else:
            print("Invalid version")


if __name__ == '__main__':
    usage_str = """
Parse training files and display them."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=usage_str)
    parser.add_argument("files", type=str, nargs="+", help="training*.gz")
    args = parser.parse_args()
    main(args)
columns = 'abcdefgh'  # NOTE(review): restores the `columns =` assignment severed at the chunk boundary.
rows = '12345678'
promotions = 'rbq'  # N is encoded as normal move

col_index = {columns[i]: i for i in range(len(columns))}
row_index = {rows[i]: i for i in range(len(rows))}


def index_to_position(x):
    """Convert a (col, row) index pair to algebraic notation, e.g. (0, 0) -> 'a1'."""
    return columns[x[0]] + rows[x[1]]


def position_to_index(p):
    """Convert algebraic notation to a (col, row) index pair, e.g. 'a1' -> (0, 0)."""
    return col_index[p[0]], row_index[p[1]]


def valid_index(i):
    """Return True iff the (col, row) pair lies on the 8x8 board."""
    if i[0] > 7 or i[0] < 0:
        return False
    if i[1] > 7 or i[1] < 0:
        return False
    return True


def queen_move(start, direction, steps):
    """Slide `steps` squares from `start` in compass `direction`.

    Returns the destination in algebraic notation, or None when the move
    leaves the board.
    """
    i = position_to_index(start)
    dir_vectors = {
        'N': (0, 1),
        'NE': (1, 1),
        'E': (1, 0),
        'SE': (1, -1),
        'S': (0, -1),
        'SW': (-1, -1),
        'W': (-1, 0),
        'NW': (-1, 1)
    }
    v = dir_vectors[direction]
    i = i[0] + v[0] * steps, i[1] + v[1] * steps
    if i is None or not valid_index(i):
        return None
    return index_to_position(i)


def knight_move(start, direction, steps):
    """Make a knight jump from `start`; the direction name picks one of the
    eight L-shapes (scaled by `steps`, which callers pass as 1).

    Returns the destination in algebraic notation, or None when the move
    leaves the board.
    """
    i = position_to_index(start)
    dir_vectors = {
        'N': (1, 2),
        'NE': (2, 1),
        'E': (2, -1),
        'SE': (1, -2),
        'S': (-1, -2),
        'SW': (-2, -1),
        'W': (-2, 1),
        'NW': (-1, 2)
    }
    v = dir_vectors[direction]
    i = i[0] + v[0] * steps, i[1] + v[1] * steps
    if not valid_index(i):
        return None
    return index_to_position(i)


def make_map(kind='matrix'):
    """Build the mapping between the 80x8x8 AZ-style policy planes and lc0's
    flat policy_index.

    kind='matrix' returns a one-hot float32 matrix of shape
    (80*8*8, len(policy_index)); kind='index' returns a flat list where each
    AZ slot holds the policy_index position of its move, or -1 if the slot is
    an illegal (off-board) move.
    """
    moves = []
    # 56 planes of queen moves
    for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
        for steps in range(1, 8):
            for r0 in rows:
                for c0 in columns:
                    start = c0 + r0
                    end = queen_move(start, direction, steps)
                    # BUGFIX(review): `end == None` -> `end is None` (identity
                    # comparison for the None singleton); behavior unchanged.
                    if end is None:
                        moves.append('illegal')
                    else:
                        moves.append(start + end)
    # 8 planes of knight moves
    for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
        for r0 in rows:
            for c0 in columns:
                start = c0 + r0
                end = knight_move(start, direction, 1)
                if end is None:
                    moves.append('illegal')
                else:
                    moves.append(start + end)
    # 9 promotion planes (queen promotion via N is encoded as a normal move).
    for direction in ['NW', 'N', 'NE']:
        for promotion in promotions:
            for r0 in rows:
                for c0 in columns:
                    # Promotion only in the second last rank
                    if r0 != '7':
                        moves.append('illegal')
                        continue
                    start = c0 + r0
                    end = queen_move(start, direction, 1)
                    if end is None:
                        moves.append('illegal')
                    else:
                        moves.append(start + end + promotion)

    # Every lc0 policy_index move must be representable in the AZ planes.
    for m in policy_index:
        if m not in moves:
            raise ValueError('Missing move: {}'.format(m))

    az_to_lc0 = np.zeros((80 * 8 * 8, len(policy_index)), dtype=np.float32)
    indices = []
    legal_moves = 0
    for e, m in enumerate(moves):
        if m == 'illegal':
            indices.append(-1)
            continue
        legal_moves += 1
        # Every legal AZ slot must also exist in lc0's policy_index.
        if m not in policy_index:
            raise ValueError('Missing move: {}'.format(m))
        i = policy_index.index(m)
        indices.append(i)
        az_to_lc0[e][i] = 1
    assert legal_moves == len(policy_index)
    assert np.sum(az_to_lc0) == legal_moves
    # BUGFIX(review): removed a leftover empty nested loop over all
    # 80*8*8 x len(policy_index) cells whose body was `pass` — pure dead
    # work with no effect on the result.
    if kind == 'matrix':
        return az_to_lc0
    elif kind == 'index':
        return indices


if __name__ == "__main__":
    # Generate policy map include file for lc0
    if len(sys.argv) != 2:
        raise ValueError(
            "Output filename is needed as a command line argument")

    az_to_lc0 = np.ravel(make_map('index'))

    # NOTE(review): line breaks inside this license header were destroyed by
    # the chunk flattening and are restored best-effort; the string is closed
    # at the chunk cut — its remaining text and the file-writing code
    # continue in the next chunk.
    header = \
"""/*
 This file is part of Leela Chess Zero.
 Copyright (C) 2019 The LCZero Authors

 Leela Chess is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation, either version 3 of the License, or
 (at your option) any later version.

 Leela Chess is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 GNU General Public License for more details.

 You should have received a copy of the GNU General Public License
 along with Leela Chess. If not, see .
"""
*/ #pragma once namespace lczero { """ line_length = 12 with open(sys.argv[1], 'w') as f: f.write(header + '\n') f.write('const short kConvPolicyMap[] = {\\\n') for e, i in enumerate(az_to_lc0): if e % line_length == 0 and e > 0: f.write('\n') f.write(str(i).rjust(5)) if e != len(az_to_lc0) - 1: f.write(',') f.write('};\n\n') f.write('} // namespace lczero') ================================================ FILE: tf/make_model.py ================================================ #!/usr/bin/env python3 import argparse import os import yaml import tfprocess argparser = argparse.ArgumentParser(description='Convert net to model.') argparser.add_argument('--start', type=int, default=0, help='Offset to set global_step to.') argparser.add_argument('--cfg', type=argparse.FileType('r'), help='yaml configuration with training parameters') args = argparser.parse_args() cfg = yaml.safe_load(args.cfg.read()) print(yaml.dump(cfg, default_flow_style=False)) START_FROM = args.start tfp = tfprocess.TFProcess(cfg) tfp.init_net() tfp.global_step.assign(START_FROM) root_dir = os.path.join(cfg['training']['path'], cfg['name']) if not os.path.exists(root_dir): os.makedirs(root_dir) tfp.manager.save(checkpoint_number=START_FROM) print("Wrote model to {}".format(tfp.manager.latest_checkpoint)) path = os.path.join(tfp.root_dir, tfp.cfg['name']) leela_path = path + "-" + str(START_FROM) swa_path = path + "-swa-" + str(START_FROM) tfp.net.pb.training_params.training_steps = START_FROM tfp.save_leelaz_weights(leela_path) if tfp.swa_enabled: tfp.save_swa_weights(swa_path) ================================================ FILE: tf/model_to_net.py ================================================ #!/usr/bin/env python3 import argparse import os import yaml import tfprocess argparser = argparse.ArgumentParser(description='Convert model to net.') argparser.add_argument('--cfg', type=argparse.FileType('r'), help='yaml configuration with training parameters') args = argparser.parse_args() cfg = 
yaml.safe_load(args.cfg.read()) print(yaml.dump(cfg, default_flow_style=False)) tfp = tfprocess.TFProcess(cfg) tfp.init_net() tfp.restore() root_dir = os.path.join(cfg['training']['path'], cfg['name']) if not os.path.exists(root_dir): os.makedirs(root_dir) path = os.path.join(tfp.root_dir, tfp.cfg['name']) steps = tfp.global_step.read_value().numpy() leela_path = path + "-" + str(steps) swa_path = path + "-swa-" + str(steps) tfp.net.pb.training_params.training_steps = steps tfp.save_leelaz_weights(leela_path) if tfp.swa_enabled: tfp.save_swa_weights(swa_path) ================================================ FILE: tf/net.py ================================================ #!/usr/bin/env python3 import argparse import gzip import os import numpy as np import proto.net_pb2 as pb LC0_MAJOR = 0 LC0_MINOR = 21 LC0_MINOR_WITH_INPUT_TYPE_3 = 25 LC0_MINOR_WITH_INPUT_TYPE_4 = 26 LC0_MINOR_WITH_INPUT_TYPE_5 = 27 LC0_MINOR_WITH_MISH = 29 LC0_MINOR_WITH_ATTN_BODY = 30 LC0_PATCH = 0 WEIGHTS_MAGIC = 0x1c0 def nested_getattr(obj, attr): attributes = attr.split(".") for a in attributes: obj = getattr(obj, a) return obj class Net: def __init__(self, net=pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT, input=pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE, value=pb.NetworkFormat.VALUE_CLASSICAL, policy=pb.NetworkFormat.POLICY_CLASSICAL, moves_left=pb.NetworkFormat.MOVES_LEFT_V1): if net == pb.NetworkFormat.NETWORK_SE: net = pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT if net == pb.NetworkFormat.NETWORK_CLASSICAL: net = pb.NetworkFormat.NETWORK_CLASSICAL_WITH_HEADFORMAT self.pb = pb.Net() self.pb.magic = WEIGHTS_MAGIC self.pb.min_version.major = LC0_MAJOR self.pb.min_version.minor = LC0_MINOR self.pb.min_version.patch = LC0_PATCH self.pb.format.weights_encoding = pb.Format.LINEAR16 self.weights = [] self.set_networkformat(net) self.pb.format.network_format.input = input self.set_policyformat(policy) self.set_valueformat(value) self.set_movesleftformat(moves_left) 
self.set_defaultactivation(pb.NetworkFormat.DEFAULT_ACTIVATION_RELU) def set_networkformat(self, net): self.pb.format.network_format.network = net if net == pb.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT \ and self.pb.min_version.minor < LC0_MINOR_WITH_ATTN_BODY: self.pb.min_version.minor = LC0_MINOR_WITH_ATTN_BODY def set_policyformat(self, policy): self.pb.format.network_format.policy = policy def set_headcount(self, headcount): self.pb.weights.headcount = headcount def set_pol_headcount(self, headcount): self.pb.weights.pol_headcount = headcount def set_valueformat(self, value): self.pb.format.network_format.value = value # OutputFormat is for search to know which kind of value the net returns. if value == pb.NetworkFormat.VALUE_WDL: self.pb.format.network_format.output = pb.NetworkFormat.OUTPUT_WDL else: self.pb.format.network_format.output = pb.NetworkFormat.OUTPUT_CLASSICAL def set_movesleftformat(self, moves_left): self.pb.format.network_format.moves_left = moves_left def set_input(self, input_format): self.pb.format.network_format.input = input_format if input_format == pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2 or input_format == pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON: self.pb.min_version.minor = LC0_MINOR_WITH_INPUT_TYPE_5 elif input_format >= pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES: self.pb.min_version.minor = LC0_MINOR_WITH_INPUT_TYPE_4 # Input type 2 was available before 3, but it was buggy, so also limit it to same version as 3. 
elif input_format != pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE: self.pb.min_version.minor = LC0_MINOR_WITH_INPUT_TYPE_3 def set_defaultactivation(self, activation): self.pb.format.network_format.default_activation = activation if activation == pb.NetworkFormat.DEFAULT_ACTIVATION_MISH: if self.pb.min_version.minor < LC0_MINOR_WITH_MISH: self.pb.min_version.minor = LC0_MINOR_WITH_MISH def set_smolgen_activation(self, activation): self.pb.format.network_format.smolgen_activation = activation if self.pb.min_version.minor < LC0_MINOR_WITH_ATTN_BODY: self.pb.min_version.minor = LC0_MINOR_WITH_ATTN_BODY return None def set_ffn_activation(self, activation): self.pb.format.network_format.ffn_activation = activation if self.pb.min_version.minor < LC0_MINOR_WITH_ATTN_BODY: self.pb.min_version.minor = LC0_MINOR_WITH_ATTN_BODY return None def activation(self, name): if name == "relu": return pb.NetworkFormat.ACTIVATION_RELU elif name == "tanh": return pb.NetworkFormat.ACTIVATION_TANH elif name == "sigmoid": return pb.NetworkFormat.ACTIVATION_SIGMOID elif name == "softmax": return pb.NetworkFormat.ACTIVATION_SOFTMAX elif name == "selu": return pb.NetworkFormat.ACTIVATION_SELU elif name == "mish": return pb.NetworkFormat.ACTIVATION_MISH elif name == "swish": return pb.NetworkFormat.ACTIVATION_SWISH elif name == "relu_2" or name == "sqrrelu": return pb.NetworkFormat.ACTIVATION_RELU_2 elif name == "default": return pb.NetworkFormat.ACTIVATION_DEFAULT else: return pb.NetworkFormat.ACTIVATION_NONE def get_weight_amounts(self): value_weights = 8 policy_weights = 6 head_weights = value_weights + policy_weights if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT: # Batch norm gammas in head convolutions. 
# NOTE(review): this chunk begins mid-method; the `def` line and opening
# statements of get_weight_amounts are restored from the previous chunk's
# tail so the fragment below stays syntactically attached to its method.
def get_weight_amounts(self):
    """Number of weight rows per network section in a legacy .txt file."""
    value_weights = 8
    policy_weights = 6
    head_weights = value_weights + policy_weights
    if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:
        # Batch norm gammas in head convolutions.
        head_weights += 2
    if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:
        return {"input": 5, "residual": 14, "head": head_weights}
    else:
        return {"input": 4, "residual": 8, "head": head_weights}


def fill_layer_v2(self, layer, params):
    """Normalize and populate 16bit layer in protobuf.

    Values are min/max-normalized to [0, 1] and stored as little-endian
    uint16 in layer.params; layer.min_val / layer.max_val keep the scale
    needed by denorm_layer_v2 to invert the quantization.
    """
    params = params.flatten().astype(np.float32)
    # Single-element layers get a fixed 0..1 scale so they round-trip exactly
    # (and an all-zero singleton avoids a degenerate 0..0 range).
    layer.min_val = 0 if len(params) == 1 else float(np.min(params))
    layer.max_val = 1 if len(params) == 1 and np.max(
        params) == 0 else float(np.max(params))
    if layer.max_val == layer.min_val:
        # Avoid division by zero if max == min.
        params = (params - layer.min_val)
    else:
        params = (params - layer.min_val) / (layer.max_val - layer.min_val)
    params *= 0xffff
    params = np.round(params)
    layer.params = params.astype(np.uint16).tobytes()


def fill_layer(self, layer, weights):
    """Normalize and populate 16bit layer in protobuf.

    Pops the next row off `weights` (consuming it) and quantizes it with the
    same min/max scheme as fill_layer_v2.
    """
    params = np.array(weights.pop(), dtype=np.float32)
    layer.min_val = 0 if len(params) == 1 else float(np.min(params))
    layer.max_val = 1 if len(params) == 1 and np.max(
        params) == 0 else float(np.max(params))
    if layer.max_val == layer.min_val:
        # Avoid division by zero if max == min.
        params = (params - layer.min_val)
    else:
        params = (params - layer.min_val) / (layer.max_val - layer.min_val)
    params *= 0xffff
    params = np.round(params)
    layer.params = params.astype(np.uint16).tobytes()


def fill_conv_block(self, convblock, weights, gammas):
    """Normalize and populate 16bit convblock in protobuf.

    `gammas` selects the SE layout (BN gammas/betas) over the classical one
    (biases); rows are consumed from the tail of `weights`.
    """
    if gammas:
        self.fill_layer(convblock.bn_stddivs, weights)
        self.fill_layer(convblock.bn_means, weights)
        self.fill_layer(convblock.bn_betas, weights)
        self.fill_layer(convblock.bn_gammas, weights)
        self.fill_layer(convblock.weights, weights)
    else:
        self.fill_layer(convblock.bn_stddivs, weights)
        self.fill_layer(convblock.bn_means, weights)
        self.fill_layer(convblock.biases, weights)
        self.fill_layer(convblock.weights, weights)


def fill_plain_conv(self, convblock, weights):
    """Normalize and populate 16bit convblock in protobuf (bias + kernel only)."""
    self.fill_layer(convblock.biases, weights)
    self.fill_layer(convblock.weights, weights)


def fill_se_unit(self, se_unit, weights):
    """Quantize an SE unit; rows are popped in reverse (b2/w2 before b1/w1)."""
    self.fill_layer(se_unit.b2, weights)
    self.fill_layer(se_unit.w2, weights)
    self.fill_layer(se_unit.b1, weights)
    self.fill_layer(se_unit.w1, weights)


def denorm_layer_v2(self, layer):
    """Denormalize a layer from protobuf.

    Inverse of fill_layer_v2: maps the stored uint16 values back onto
    [min_val, max_val] and returns a float32 array.
    """
    params = np.frombuffer(layer.params, np.uint16).astype(np.float32)
    params /= 0xffff
    return params * (layer.max_val - layer.min_val) + layer.min_val


def denorm_layer(self, layer, weights):
    """Denormalize a layer and prepend it to `weights`.

    Prepending lets callers that walk the net in reverse rebuild the
    original row order.
    """
    weights.insert(0, self.denorm_layer_v2(layer))


def denorm_conv_block(self, convblock, weights):
    """Denormalize a convblock from protobuf."""
    se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT
    if se:
        self.denorm_layer(convblock.bn_stddivs, weights)
        self.denorm_layer(convblock.bn_means, weights)
        self.denorm_layer(convblock.bn_betas, weights)
        self.denorm_layer(convblock.bn_gammas, weights)
        self.denorm_layer(convblock.weights, weights)
    else:
        self.denorm_layer(convblock.bn_stddivs, weights)
        self.denorm_layer(convblock.bn_means, weights)
        self.denorm_layer(convblock.biases, weights)
        # NOTE(review): the final `self.denorm_layer(convblock.weights,
        # weights)` of this branch is severed at the chunk boundary and
        # continues in the next chunk.
self.denorm_layer(convblock.weights, weights) def denorm_plain_conv(self, convblock, weights): """Denormalize a plain convolution from protobuf""" self.denorm_layer(convblock.biases, weights) self.denorm_layer(convblock.weights, weights) def denorm_se_unit(self, convblock, weights): """Denormalize SE-unit from protobuf""" se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT assert se self.denorm_layer(convblock.b2, weights) self.denorm_layer(convblock.w2, weights) self.denorm_layer(convblock.b1, weights) self.denorm_layer(convblock.w1, weights) def save_txt(self, filename): """Save weights as txt file""" weights = self.get_weights() if len(filename.split('.')) == 1: filename += ".txt.gz" # Legacy .txt files are version 2, SE is version 3. version = 2 if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT: version = 3 if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION: version = 4 with gzip.open(filename, 'wb') as f: f.write("{}\n".format(version).encode('utf-8')) for row in weights: f.write( (" ".join(map(str, row.tolist())) + "\n").encode('utf-8')) size = os.path.getsize(filename) / 1024**2 print("saved as '{}' {}M".format(filename, round(size, 2))) def save_proto(self, filename): """Save weights gzipped protobuf file""" if len(filename.split('.')) == 1: filename += ".pb.gz" with gzip.open(filename, 'wb') as f: data = self.pb.SerializeToString() f.write(data) size = os.path.getsize(filename) / 1024**2 print("Weights saved as '{}' {}M".format(filename, round(size, 2))) def tf_name_to_pb_name(self, name): """Given Tensorflow variable name returns the protobuf name and index of residual block if weight belong in a residual block.""" def convblock_to_bp(w): w = w.split(':')[0] d = { 'kernel': 'weights', 'gamma': 'bn_gammas', 'beta': 'bn_betas', 'moving_mean': 'bn_means', 'moving_variance': 'bn_stddivs', 'bias': 'biases' } return d[w] def se_to_bp(l, w): if l == 'dense1': 
n = 1 elif l == 'dense2': n = 2 else: raise ValueError('Unable to decode SE-weight {}/{}'.format( l, w)) w = w.split(':')[0] d = {'kernel': 'w', 'bias': 'b'} return d[w] + str(n) def value_to_bp(l, w): if l == 'embedding': n = '' elif l == 'dense1': n = 1 elif l == 'dense2': n = 2 else: raise ValueError('Unable to decode value weight {}/{}'.format( l, w)) w = w.split(':')[0] d = {'kernel': 'ip{}_val_w', 'bias': 'ip{}_val_b'} return d[w].format(n) def conv_policy_to_bp(w): w = w.split(':')[0] d = {'kernel': 'ip_pol_w', 'bias': 'ip_pol_b'} return d[w] def attn_pol_to_bp(l, w): if l == 'wq': n = 2 elif l == 'wk': n = 3 elif l == 'ppo': n = 4 else: raise ValueError( 'Unable to decode attn_policy weight {}/{}'.format(l, w)) w = w.split(':')[0] d = {'kernel': 'ip{}_pol_w', 'bias': 'ip{}_pol_b'} return d[w].format(n) def encoder_to_bp(l, w): if l == 'ln1': n = 1 elif l == 'ln2': n = 2 else: raise ValueError( 'Unable to decode encoder weight {}/{}'.format(l, w)) w = w.split(':')[0] d = {'gamma': 'ln{}_gammas', 'beta': 'ln{}_betas'} return d[w].format(n) def mha_to_bp(l, w): s = '' if l.startswith('dense'): s = 'dense' elif l.startswith('w'): s = l[1] else: raise ValueError('Unable to decode mha weight {}/{}'.format( l, w)) w = w.split(':')[0] d = {'kernel': '{}_w', 'bias': '{}_b'} return d[w].format(s) def mha_smolgen_to_bp(l, w): s = { 'compress': 'compress', 'hidden1_dense': 'dense1_{}', 'hidden1_ln': 'ln1_{}', 'gen_from': 'dense2_{}', 'gen_from_ln': 'ln2_{}' } if s[l] is None: raise ValueError( 'Unable to decode mha smolgen weight {}/{}'.format(l, w)) w = w.split(':')[0] d = { 'kernel': 'w', 'bias': 'b', 'gamma': 'gammas', 'beta': 'betas' } return s[l].format(d[w]) def ffn_to_bp(l, w): w = w.split(':')[0] d = {'kernel': '{}_w', 'bias': '{}_b'} return d[w].format(l) def moves_left_to_bp(l, w): if l == 'embedding': n = '' elif l == 'dense1': n = 1 elif l == 'dense2': n = 2 else: raise ValueError( 'Unable to decode moves_left weight {}/{}'.format(l, w)) w = w.split(':')[0] 
d = {'kernel': 'ip{}_mov_w', 'bias': 'ip{}_mov_b'} return d[w].format(n) layers = name.split('/') base_layer = layers[0] weights_name = layers[-1] pb_name = None block = None encoder_block = None pol_encoder_block = None if base_layer == 'input': pb_name = 'input.' + convblock_to_bp(weights_name) elif base_layer == 'policy1': pb_name = 'policy1.' + convblock_to_bp(weights_name) elif base_layer == 'policy': if 'dense' in layers[1]: pb_name = conv_policy_to_bp(weights_name) elif layers[1] == 'embedding': if layers[2].split(':')[0] == 'kernel': pb_name = 'ip_pol_w' else: pb_name = 'ip_pol_b' elif layers[1] == 'attention': pb_name = attn_pol_to_bp(layers[2], weights_name) elif layers[1].startswith('enc_layer_'): pol_encoder_block = int(layers[1].split('_')[2]) - 1 if layers[2] == 'mha': pb_name = 'mha.' + mha_to_bp(layers[3], weights_name) elif layers[2] == 'ffn': pb_name = 'ffn.' + ffn_to_bp(layers[3], weights_name) else: pb_name = encoder_to_bp(layers[2], weights_name) else: pb_name = 'policy.' + convblock_to_bp(weights_name) elif base_layer == 'value': if 'dense' in layers[1] or 'embedding' in layers[1]: pb_name = value_to_bp(layers[1], weights_name) else: pb_name = 'value.' + convblock_to_bp(weights_name) elif base_layer == 'moves_left': if 'dense' in layers[1] or 'embedding' in layers[1]: pb_name = moves_left_to_bp(layers[1], weights_name) else: pb_name = 'moves_left.' + convblock_to_bp(weights_name) elif base_layer.startswith('residual'): block = int(base_layer.split('_')[1]) - 1 # 1 indexed if layers[1] == '1': pb_name = 'conv1.' + convblock_to_bp(weights_name) elif layers[1] == '2': pb_name = 'conv2.' + convblock_to_bp(weights_name) elif layers[1] == 'se': pb_name = 'se.' + se_to_bp(layers[-2], weights_name) elif base_layer.startswith('encoder'): encoder_block = int(base_layer.split('_')[1]) - 1 if layers[1] == 'mha': if layers[2] == 'smolgen': pb_name = 'mha.smolgen.' + mha_smolgen_to_bp( layers[3], weights_name) else: pb_name = 'mha.' 
+ mha_to_bp(layers[2], weights_name) elif layers[1] == 'ffn': pb_name = 'ffn.' + ffn_to_bp(layers[2], weights_name) else: pb_name = encoder_to_bp(layers[1], weights_name) elif base_layer == 'embedding': if layers[1] == 'mult_gate' or layers[1] == 'add_gate': if layers[2].split(':')[0] == 'gate': pb_name = 'ip_{}'.format(layers[1]) elif layers[1].split(':')[0] == 'kernel': pb_name = 'ip_emb_w' elif layers[1].split(':')[0] == 'bias': pb_name = 'ip_emb_b' elif base_layer == 'smol_weight_gen': if layers[1].split(':')[0] == 'kernel': pb_name = 'smolgen_w' else: pb_name = 'smolgen_b' return (pb_name, block, pol_encoder_block, encoder_block) def get_weights_v2(self, names): # `names` is a list of Tensorflow tensor names to get from the protobuf. # Returns list of [Tensor name, Tensor weights]. tensors = {} for tf_name in names: name = tf_name if 'stddev' in name: # Get variance instead of stddev. name = name.replace('stddev', 'variance') if 'renorm' in name: # Renorm variables are not populated. continue if 'headcount' in tf_name: # headcount is set with set_headcount() continue pb_name, block, pol_encoder_block, encoder_block = self.tf_name_to_pb_name( name) if pb_name is None: raise ValueError( "Don't know where to store weight in protobuf: {}".format( name)) if block is None: if pol_encoder_block is not None: pb_weights = self.pb.weights.pol_encoder[pol_encoder_block] elif encoder_block is not None: pb_weights = self.pb.weights.encoder[encoder_block] else: pb_weights = self.pb.weights else: pb_weights = self.pb.weights.residual[block] w = self.denorm_layer_v2(nested_getattr(pb_weights, pb_name)) # Only variance is stored in the protobuf. 
if 'stddev' in tf_name: w = np.sqrt(w + 1e-5) tensors[tf_name] = w return tensors def get_weights(self): """Returns the weights as floats per layer""" se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT if self.weights == []: self.denorm_layer(self.pb.weights.ip2_val_b, self.weights) self.denorm_layer(self.pb.weights.ip2_val_w, self.weights) self.denorm_layer(self.pb.weights.ip1_val_b, self.weights) self.denorm_layer(self.pb.weights.ip1_val_w, self.weights) self.denorm_conv_block(self.pb.weights.value, self.weights) if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION: self.denorm_plain_conv(self.pb.weights.policy, self.weights) self.denorm_conv_block(self.pb.weights.policy1, self.weights) else: self.denorm_layer(self.pb.weights.ip_pol_b, self.weights) self.denorm_layer(self.pb.weights.ip_pol_w, self.weights) self.denorm_conv_block(self.pb.weights.policy, self.weights) for res in reversed(self.pb.weights.residual): if se: self.denorm_se_unit(res.se, self.weights) self.denorm_conv_block(res.conv2, self.weights) self.denorm_conv_block(res.conv1, self.weights) self.denorm_conv_block(self.pb.weights.input, self.weights) return self.weights def filters(self): layer = self.pb.weights.input.bn_means params = np.frombuffer(layer.params, np.uint16).astype(np.float32) return len(params) def blocks(self): return len(self.pb.weights.residual) def print_stats(self): print("Blocks: {}".format(self.blocks())) print("Filters: {}".format(self.filters())) print_pb_stats(self.pb) print() def parse_proto(self, filename): with gzip.open(filename, 'rb') as f: self.pb = self.pb.FromString(f.read()) # Populate policyFormat and valueFormat fields in old protobufs # without these fields. 
if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE: self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT) self.set_valueformat(pb.NetworkFormat.VALUE_CLASSICAL) self.set_policyformat(pb.NetworkFormat.POLICY_CLASSICAL) self.set_movesleftformat(pb.NetworkFormat.MOVES_LEFT_NONE) elif self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_CLASSICAL: self.set_networkformat( pb.NetworkFormat.NETWORK_CLASSICAL_WITH_HEADFORMAT) self.set_valueformat(pb.NetworkFormat.VALUE_CLASSICAL) self.set_policyformat(pb.NetworkFormat.POLICY_CLASSICAL) self.set_movesleftformat(pb.NetworkFormat.MOVES_LEFT_NONE) def parse_txt(self, filename): weights = [] with open(filename, 'r') as f: try: version = int(f.readline()[0]) except: raise ValueError('Unable to read version.') for e, line in enumerate(f): weights.append(list(map(float, line.split(' ')))) if version == 3: self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT) if version == 4: self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT) self.set_policyformat(pb.NetworkFormat.POLICY_CONVOLUTION) self.fill_net(weights) def fill_net_v2(self, all_weights): # all_weights is array of [name of weight, numpy array of weights]. self.pb.format.weights_encoding = pb.Format.LINEAR16 has_renorm = any('renorm' in w[0] for w in all_weights) weight_names = [w[0] for w in all_weights] del self.pb.weights.residual[:] for name, weights in all_weights: layers = name.split('/') weights_name = layers[-1] if weights.ndim == 4: # Convolution weights need a transpose # # TF # [filter_height, filter_width, in_channels, out_channels] # # Leela # [output, input, filter_size, filter_size] weights = np.transpose(weights, axes=[3, 2, 0, 1]) elif weights.ndim == 2: # Fully connected layers are [in, out] in TF # # [out, in] in Leela # weights = np.transpose(weights, axes=[1, 0]) if 'renorm' in name: # Batch renorm has extra weights, but we don't know what to do with them. 
continue if has_renorm: if 'variance:' in weights_name: # Renorm has variance, but it is not the primary source of truth. continue # Renorm has moving stddev not variance, undo the transform to make it compatible. if 'stddev:' in weights_name: weights = np.square(weights) - 1e-5 name = name.replace('stddev', 'variance') if self.pb.format.network_format.input < pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES: if name == 'input/conv2d/kernel:0': # 50 move rule is the 110th input, or 109 starting from 0. weights[:, 109, :, :] /= 99 elif name == 'embedding/kernel:0': weights[:, 109] /= 99 pb_name, block, pol_encoder_block, encoder_block = self.tf_name_to_pb_name( name) if pb_name is None: raise ValueError( "Don't know where to store weight in protobuf: {}".format( name)) if block is None: if pol_encoder_block is not None: assert pol_encoder_block >= 0 while pol_encoder_block >= len( self.pb.weights.pol_encoder): self.pb.weights.pol_encoder.add() pb_weights = self.pb.weights.pol_encoder[pol_encoder_block] elif encoder_block is not None: assert encoder_block >= 0 while encoder_block >= len(self.pb.weights.encoder): self.pb.weights.encoder.add() pb_weights = self.pb.weights.encoder[encoder_block] else: pb_weights = self.pb.weights else: assert block >= 0 while block >= len(self.pb.weights.residual): self.pb.weights.residual.add() pb_weights = self.pb.weights.residual[block] self.fill_layer_v2(nested_getattr(pb_weights, pb_name), weights) if pb_name.endswith('bn_betas'): # Check if we need to add constant one gammas. gamma_name = name.replace('beta', 'gamma') if gamma_name in weight_names: continue gamma = np.ones(weights.shape) pb_gamma = pb_name.replace('bn_betas', 'bn_gammas') self.fill_layer_v2(nested_getattr(pb_weights, pb_gamma), gamma) def fill_net(self, weights): self.weights = [] # Batchnorm gammas in ConvBlock? 
se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT gammas = se ws = self.get_weight_amounts() blocks = len(weights) - (ws['input'] + ws['head']) if blocks % ws['residual'] != 0: raise ValueError("Inconsistent number of weights in the file") blocks //= ws['residual'] self.pb.format.weights_encoding = pb.Format.LINEAR16 self.fill_layer(self.pb.weights.ip2_val_b, weights) self.fill_layer(self.pb.weights.ip2_val_w, weights) self.fill_layer(self.pb.weights.ip1_val_b, weights) self.fill_layer(self.pb.weights.ip1_val_w, weights) self.fill_conv_block(self.pb.weights.value, weights, gammas) if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION: self.fill_plain_conv(self.pb.weights.policy, weights) self.fill_conv_block(self.pb.weights.policy1, weights, gammas) else: self.fill_layer(self.pb.weights.ip_pol_b, weights) self.fill_layer(self.pb.weights.ip_pol_w, weights) self.fill_conv_block(self.pb.weights.policy, weights, gammas) del self.pb.weights.residual[:] tower = [] for i in range(blocks): tower.append(self.pb.weights.residual.add()) for res in reversed(tower): if se: self.fill_se_unit(res.se, weights) self.fill_conv_block(res.conv2, weights, gammas) self.fill_conv_block(res.conv1, weights, gammas) self.fill_conv_block(self.pb.weights.input, weights, gammas) def print_pb_stats(obj, parent=None): for descriptor in obj.DESCRIPTOR.fields: value = getattr(obj, descriptor.name) if descriptor.name == "weights": return if descriptor.type == descriptor.TYPE_MESSAGE: if descriptor.label == descriptor.LABEL_REPEATED: map(print_pb_stats, value) else: print_pb_stats(value, obj) elif descriptor.type == descriptor.TYPE_ENUM: enum_name = descriptor.enum_type.values[value].name print("%s: %s" % (descriptor.full_name, enum_name)) else: print("%s: %s" % (descriptor.full_name, value)) def main(argv): net = Net() if argv.input.endswith(".txt"): print('Found .txt network') net.parse_txt(argv.input) net.print_stats() if 
argv.output == None: argv.output = argv.input.replace('.txt', '.pb.gz') assert argv.output.endswith('.pb.gz') print('Writing output to: {}'.format(argv.output)) net.save_proto(argv.output) elif argv.input.endswith(".pb.gz"): print('Found .pb.gz network') net.parse_proto(argv.input) net.print_stats() if argv.output == None: argv.output = argv.input.replace('.pb.gz', '.txt.gz') print('Writing output to: {}'.format(argv.output)) assert argv.output.endswith('.txt.gz') if argv.output.endswith(".pb.gz"): net.save_proto(argv.output) else: net.save_txt(argv.output) else: print('Unable to detect the network format. ' 'Filename should end in ".txt" or ".pb.gz"') if __name__ == "__main__": argparser = argparse.ArgumentParser( description='Convert network textfile to proto.') argparser.add_argument('-i', '--input', type=str, help='input network weight text file') argparser.add_argument('-o', '--output', type=str, help='output filepath without extension') main(argparser.parse_args()) ================================================ FILE: tf/net_to_model.py ================================================ #!/usr/bin/env python3 import argparse import os import yaml import tfprocess argparser = argparse.ArgumentParser(description='Convert net to model.') argparser.add_argument('net', type=str, help='Net file to be converted to a model checkpoint.') argparser.add_argument('--start', type=int, default=0, help='Offset to set global_step to.') argparser.add_argument('--cfg', type=argparse.FileType('r'), help='yaml configuration with training parameters') argparser.add_argument('-e', '--ignore-errors', action='store_true', help='Ignore missing and wrong sized values.') args = argparser.parse_args() cfg = yaml.safe_load(args.cfg.read()) print(yaml.dump(cfg, default_flow_style=False)) START_FROM = args.start tfp = tfprocess.TFProcess(cfg) tfp.init_net() tfp.replace_weights(args.net, args.ignore_errors) tfp.global_step.assign(START_FROM) root_dir = os.path.join(cfg['training']['path'], 
cfg['name']) if not os.path.exists(root_dir): os.makedirs(root_dir) tfp.manager.save(checkpoint_number=START_FROM) print("Wrote model to {}".format(tfp.manager.latest_checkpoint)) ================================================ FILE: tf/policy_index.py ================================================ policy_index = [ "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2", "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6", "a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1", "b1c1", "b1d1", "b1e1", "b1f1", "b1g1", "b1h1", "b1a2", "b1b2", "b1c2", "b1d2", "b1a3", "b1b3", "b1c3", "b1d3", "b1b4", "b1e4", "b1b5", "b1f5", "b1b6", "b1g6", "b1b7", "b1h7", "b1b8", "c1a1", "c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1", "c1a2", "c1b2", "c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3", "c1e3", "c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8", "d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2", "d1c2", "d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3", "d1f3", "d1a4", "d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7", "d1d8", "e1a1", "e1b1", "e1c1", "e1d1", "e1f1", "e1g1", "e1h1", "e1c2", "e1d2", "e1e2", "e1f2", "e1g2", "e1c3", "e1d3", "e1e3", "e1f3", "e1g3", "e1b4", "e1e4", "e1h4", "e1a5", "e1e5", "e1e6", "e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1", "f1e1", "f1g1", "f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3", "f1e3", "f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6", "f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1", "g1f1", "g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3", "g1g3", "g1h3", "g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6", "g1a7", "g1g7", "g1g8", "h1a1", "h1b1", "h1c1", "h1d1", "h1e1", "h1f1", "h1g1", "h1f2", "h1g2", "h1h2", "h1f3", "h1g3", "h1h3", "h1e4", "h1h4", "h1d5", "h1h5", "h1c6", "h1h6", "h1b7", "h1h7", "h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2", "a2c2", "a2d2", "a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3", "a2a4", "a2b4", "a2c4", 
"a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7", "a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2", "b2d2", "b2e2", "b2f2", "b2g2", "b2h2", "b2a3", "b2b3", "b2c3", "b2d3", "b2a4", "b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6", "b2f6", "b2b7", "b2g7", "b2b8", "b2h8", "c2a1", "c2b1", "c2c1", "c2d1", "c2e1", "c2a2", "c2b2", "c2d2", "c2e2", "c2f2", "c2g2", "c2h2", "c2a3", "c2b3", "c2c3", "c2d3", "c2e3", "c2a4", "c2b4", "c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6", "c2g6", "c2c7", "c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1", "d2a2", "d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3", "d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4", "d2a5", "d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1", "e2d1", "e2e1", "e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2", "e2f2", "e2g2", "e2h2", "e2c3", "e2d3", "e2e3", "e2f3", "e2g3", "e2c4", "e2d4", "e2e4", "e2f4", "e2g4", "e2b5", "e2e5", "e2h5", "e2a6", "e2e6", "e2e7", "e2e8", "f2d1", "f2e1", "f2f1", "f2g1", "f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2", "f2g2", "f2h2", "f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4", "f2f4", "f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7", "f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2", "g2d2", "g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3", "g2e4", "g2f4", "g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6", "g2b7", "g2g7", "g2a8", "g2g8", "h2f1", "h2g1", "h2h1", "h2a2", "h2b2", "h2c2", "h2d2", "h2e2", "h2f2", "h2g2", "h2f3", "h2g3", "h2h3", "h2f4", "h2g4", "h2h4", "h2e5", "h2h5", "h2d6", "h2h6", "h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1", "a3c1", "a3a2", "a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3", "a3g3", "a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6", "a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1", "b3d1", "b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3", "b3e3", "b3f3", "b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4", "b3a5", "b3b5", "b3c5", "b3d5", 
"b3b6", "b3e6", "b3b7", "b3f7", "b3b8", "b3g8", "c3a1", "c3b1", "c3c1", "c3d1", "c3e1", "c3a2", "c3b2", "c3c2", "c3d2", "c3e2", "c3a3", "c3b3", "c3d3", "c3e3", "c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4", "c3d4", "c3e4", "c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6", "c3c7", "c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1", "d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3", "d3e3", "d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4", "d3f4", "d3b5", "d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6", "d3g6", "d3d7", "d3h7", "d3d8", "e3c1", "e3d1", "e3e1", "e3f1", "e3g1", "e3c2", "e3d2", "e3e2", "e3f2", "e3g2", "e3a3", "e3b3", "e3c3", "e3d3", "e3f3", "e3g3", "e3h3", "e3c4", "e3d4", "e3e4", "e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5", "e3g5", "e3b6", "e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1", "f3f1", "f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3", "f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4", "f3f4", "f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5", "f3c6", "f3f6", "f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1", "g3g1", "g3h1", "g3e2", "g3f2", "g3g2", "g3h2", "g3a3", "g3b3", "g3c3", "g3d3", "g3e3", "g3f3", "g3h3", "g3e4", "g3f4", "g3g4", "g3h4", "g3e5", "g3f5", "g3g5", "g3h5", "g3d6", "g3g6", "g3c7", "g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1", "h3f2", "h3g2", "h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3", "h3g3", "h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6", "h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2", "a4c2", "a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4", "a4f4", "a4g4", "a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6", "a4c6", "a4a7", "a4d7", "a4a8", "a4e8", "b4b1", "b4e1", "b4a2", "b4b2", "b4c2", "b4d2", "b4a3", "b4b3", "b4c3", "b4d3", "b4a4", "b4c4", "b4d4", "b4e4", "b4f4", "b4g4", "b4h4", "b4a5", "b4b5", "b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6", "b4b7", "b4e7", "b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", 
"c4c2", "c4d2", "c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4", "c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5", "c4d5", "c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7", "c4f7", "c4c8", "c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2", "d4d2", "d4e2", "d4f2", "d4b3", "d4c3", "d4d3", "d4e3", "d4f3", "d4a4", "d4b4", "d4c4", "d4e4", "d4f4", "d4g4", "d4h4", "d4b5", "d4c5", "d4d5", "d4e5", "d4f5", "d4b6", "d4c6", "d4d6", "d4e6", "d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8", "e4b1", "e4e1", "e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3", "e4d3", "e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4", "e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6", "e4d6", "e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8", "e4e8", "f4c1", "f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2", "f4d3", "f4e3", "f4f3", "f4g3", "f4h3", "f4a4", "f4b4", "f4c4", "f4d4", "f4e4", "f4g4", "f4h4", "f4d5", "f4e5", "f4f5", "f4g5", "f4h5", "f4d6", "f4e6", "f4f6", "f4g6", "f4h6", "f4c7", "f4f7", "f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2", "g4g2", "g4h2", "g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4", "g4d4", "g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6", "g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1", "h4h1", "h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4", "h4b4", "h4c4", "h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5", "h4h5", "h4f6", "h4g6", "h4h6", "h4e7", "h4h7", "h4d8", "h4h8", "a5a1", "a5e1", "a5a2", "a5d2", "a5a3", "a5b3", "a5c3", "a5a4", "a5b4", "a5c4", "a5b5", "a5c5", "a5d5", "a5e5", "a5f5", "a5g5", "a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7", "a5c7", "a5a8", "a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3", "b5c3", "b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5", "b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6", "b5a7", "b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1", "c5c2", "c5f2", "c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4", 
"c5b4", "c5c4", "c5d4", "c5e4", "c5a5", "c5b5", "c5d5", "c5e5", "c5f5", "c5g5", "c5h5", "c5a6", "c5b6", "c5c6", "c5d6", "c5e6", "c5a7", "c5b7", "c5c7", "c5d7", "c5e7", "c5c8", "c5f8", "d5d1", "d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3", "d5d3", "d5e3", "d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5", "d5b5", "d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6", "d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8", "d5d8", "d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3", "e5d3", "e5e3", "e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4", "e5g4", "e5a5", "e5b5", "e5c5", "e5d5", "e5f5", "e5g5", "e5h5", "e5c6", "e5d6", "e5e6", "e5f6", "e5g6", "e5c7", "e5d7", "e5e7", "e5f7", "e5g7", "e5b8", "e5e8", "e5h8", "f5b1", "f5f1", "f5c2", "f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3", "f5d4", "f5e4", "f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5", "f5e5", "f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7", "f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1", "g5d2", "g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4", "g5g4", "g5h4", "g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5", "g5h5", "g5e6", "g5f6", "g5g6", "g5h6", "g5e7", "g5f7", "g5g7", "g5h7", "g5d8", "g5g8", "h5d1", "h5h1", "h5e2", "h5h2", "h5f3", "h5g3", "h5h3", "h5f4", "h5g4", "h5h4", "h5a5", "h5b5", "h5c5", "h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6", "h5h6", "h5f7", "h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2", "a6e2", "a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5", "a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7", "a6b7", "a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2", "b6f2", "b6b3", "b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5", "b6b5", "b6c5", "b6d5", "b6a6", "b6c6", "b6d6", "b6e6", "b6f6", "b6g6", "b6h6", "b6a7", "b6b7", "b6c7", "b6d7", "b6a8", "b6b8", "b6c8", "b6d8", "c6c1", "c6h1", "c6c2", "c6g2", "c6c3", "c6f3", "c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5", "c6b5", "c6c5", "c6d5", "c6e5", 
"c6a6", "c6b6", "c6d6", "c6e6", "c6f6", "c6g6", "c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8", "c6c8", "c6d8", "c6e8", "d6d1", "d6d2", "d6h2", "d6a3", "d6d3", "d6g3", "d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5", "d6d5", "d6e5", "d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6", "d6g6", "d6h6", "d6b7", "d6c7", "d6d7", "d6e7", "d6f7", "d6b8", "d6c8", "d6d8", "d6e8", "d6f8", "e6e1", "e6a2", "e6e2", "e6b3", "e6e3", "e6h3", "e6c4", "e6d4", "e6e4", "e6f4", "e6g4", "e6c5", "e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6", "e6c6", "e6d6", "e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7", "e6g7", "e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2", "f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4", "f6d5", "f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6", "f6d6", "f6e6", "f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7", "f6h7", "f6d8", "f6e8", "f6f8", "f6g8", "f6h8", "g6b1", "g6g1", "g6c2", "g6g2", "g6d3", "g6g3", "g6e4", "g6f4", "g6g4", "g6h4", "g6e5", "g6f5", "g6g5", "g6h5", "g6a6", "g6b6", "g6c6", "g6d6", "g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7", "g6h7", "g6e8", "g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2", "h6e3", "h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6", "h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7", "h6h7", "h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2", "a7a3", "a7e3", "a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6", "a7b6", "a7c6", "a7b7", "a7c7", "a7d7", "a7e7", "a7f7", "a7g7", "a7h7", "a7a8", "a7b8", "a7c8", "b7b1", "b7h1", "b7b2", "b7g2", "b7b3", "b7f3", "b7b4", "b7e4", "b7a5", "b7b5", "b7c5", "b7d5", "b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7", "b7d7", "b7e7", "b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8", "c7c1", "c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5", "c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6", "c7a7", "c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8", "c7b8", "c7c8", "c7d8", "c7e8", 
"d7d1", "d7d2", "d7d3", "d7h3", "d7a4", "d7d4", "d7g4", "d7b5", "d7c5", "d7d5", "d7e5", "d7f5", "d7b6", "d7c6", "d7d6", "d7e6", "d7f6", "d7a7", "d7b7", "d7c7", "d7e7", "d7f7", "d7g7", "d7h7", "d7b8", "d7c8", "d7d8", "d7e8", "d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4", "e7e4", "e7h4", "e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6", "e7e6", "e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7", "e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2", "f7f2", "f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5", "f7g5", "f7h5", "f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7", "f7b7", "f7c7", "f7d7", "f7e7", "f7g7", "f7h7", "f7d8", "f7e8", "f7f8", "f7g8", "f7h8", "g7a1", "g7g1", "g7b2", "g7g2", "g7c3", "g7g3", "g7d4", "g7g4", "g7e5", "g7f5", "g7g5", "g7h5", "g7e6", "g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7", "g7d7", "g7e7", "g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1", "h7h1", "h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5", "h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7", "h7e7", "h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1", "a8a2", "a8g2", "a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5", "a8a6", "a8b6", "a8c6", "a8a7", "a8b7", "a8c7", "a8b8", "a8c8", "a8d8", "a8e8", "a8f8", "a8g8", "a8h8", "b8b1", "b8b2", "b8h2", "b8b3", "b8g3", "b8b4", "b8f4", "b8b5", "b8e5", "b8a6", "b8b6", "b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7", "b8a8", "b8c8", "b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2", "c8c3", "c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6", "c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8", "c8b8", "c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2", "d8d3", "d8d4", "d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6", "d8d6", "d8e6", "d8f6", "d8b7", "d8c7", "d8d7", "d8e7", "d8f7", "d8a8", "d8b8", "d8c8", "d8e8", "d8f8", "d8g8", "d8h8", "e8e1", "e8e2", "e8e3", "e8a4", "e8e4", "e8b5", "e8e5", "e8h5", "e8c6", "e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7", 
"e8e7", "e8f7", "e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8", "e8h8", "f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5", "f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7", "f8g7", "f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8", "f8h8", "g8g1", "g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4", "g8d5", "g8g5", "g8e6", "g8f6", "g8g6", "g8h6", "g8e7", "g8f7", "g8g7", "g8h7", "g8a8", "g8b8", "g8c8", "g8d8", "g8e8", "g8f8", "g8h8", "h8a1", "h8h1", "h8b2", "h8h2", "h8c3", "h8h3", "h8d4", "h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6", "h8f7", "h8g7", "h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8", "h8g8", "a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q", "b7a8r", "b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b", "c7b8q", "c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r", "c7d8b", "d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q", "d7e8r", "d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b", "e7f8q", "e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r", "f7f8b", "f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q", "g7g8r", "g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b", "h7h8q", "h7h8r", "h7h8b" ] ================================================ FILE: tf/requirements.txt ================================================ numpy==1.13.3 tensorflow==2.5.1 tensorflow-tensorboard==0.4.0rc2 protobuf==3.12.1 ================================================ FILE: tf/shufflebuffer.py ================================================ #!/usr/bin/env python3 # # This file is part of Leela Chess. # Copyright (C) 2018 Michael O # # Leela Chess is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. 
class ShuffleBuffer:
    """Fixed-capacity shuffle buffer over equally sized byte records.

    Holds up to 'elem_count' records of exactly 'elem_size' bytes each in
    one preallocated bytearray.  Records are kept in shuffled order via an
    incremental Fisher-Yates shuffle, so popping from the tail is an
    unbiased random draw.
    """

    def __init__(self, elem_size, elem_count):
        """Create a buffer for 'elem_count' records of 'elem_size' bytes each."""
        assert elem_size > 0, elem_size
        assert elem_count > 0, elem_count
        # Size of each element, in bytes.
        self.elem_size = elem_size
        # Maximum number of elements the buffer can hold.
        self.elem_count = elem_count
        # Single fixed-size backing store for all records.
        self.buffer = bytearray(elem_size * elem_count)
        # Number of records currently stored.
        self.used = 0

    def _slot(self, idx):
        """Return the slice covering record number 'idx' in the backing store."""
        start = idx * self.elem_size
        return slice(start, start + self.elem_size)

    def extract(self):
        """Pop one record from the buffer, or None if it is empty.

        The stored records are already shuffled, so taking the last one is
        a uniform random sample.
        """
        if not self.used:
            return None
        self.used -= 1
        return self.buffer[self._slot(self.used)]

    def insert_or_replace(self, item):
        """Insert 'item', returning a randomly displaced record.

        Returns None while the buffer is still filling up.
        """
        assert len(item) == self.elem_size, len(item)
        # Swap the new record into a random occupied slot; the displaced
        # record either goes to the tail (buffer not full) or is returned.
        # This realizes a full Fisher-Yates shuffle incrementally.
        if self.used > 0:
            victim = self._slot(random.randint(0, self.used - 1))
            item, self.buffer[victim] = self.buffer[victim], item
        if self.used == self.elem_count:
            # Full: hand the displaced record back to the caller.
            return item
        # Still room: park the (possibly swapped) record at the tail.
        self.buffer[self._slot(self.used)] = item
        self.used += 1
        return None
#!/usr/bin/env bash
set -e

# Network architectures trained on every cycle.
NETARCHS=(64x6)
REPO="origin"
ROOT="/work/lc0"
NETDIR="$ROOT/networks/upload"
# File holding the number of the last game already trained on.
GAMEFILE="$HOME/.lc0.dat"
# File holding the newest game number available.
LATESTFILE="$HOME/.lc0.latest.dat"
RAMDISK="/ramdisk"
# Minimum number of 60s poll intervals between two training cycles.
MIN_GAP=10

function usage() {
    echo "Starts a training pipeline"
    echo ""
    echo "./start.sh"
    echo " -h --help"
    echo " -c --cfg The configuration directory"
    echo " -g --games The number of games between training cycles"
    echo " -b --branch The git branch to push configs to"
    echo ""
    echo "Example: ./start.sh -c=/tmp/cfgdir -g=40000 -b=test"
    echo ""
}

# Parse KEY=VALUE style options.
while [ "$1" != "" ]
do
    PARAM=`echo $1 | awk -F= '{print $1}'`
    VALUE=`echo $1 | awk -F= '{print $2}'`
    case $PARAM in
        -h | --help)
            usage
            exit
            ;;
        -c | --cfg)
            CONFIGDIR=$VALUE
            ;;
        -g | --games)
            GAMES=$VALUE
            ;;
        # FIX: usage() and the example advertise -b/--branch, but this arm
        # was missing, so the documented invocation aborted via the
        # unknown-parameter branch below.
        -b | --branch)
            BRANCH=$VALUE
            ;;
        *)
            echo "ERROR: unknown parameter \"$PARAM\""
            usage
            exit 1
            ;;
    esac
    shift
done

if [ ! -f "$GAMEFILE" ]
then
    echo "File $GAMEFILE must contain a single number, exiting now!"
    exit 1
fi

if [ -z "$LC0LOCKFILE" ]
then
    echo "env var LC0LOCKFILE not set"
    exit 1
fi

game_num=$(cat $GAMEFILE)
game_num=$((game_num + GAMES))
file="training.${game_num}.gz"
echo "Starting with '$file' as last game in window"

# Train one network config and move the produced weights to the upload dir.
train() {
    unbuffer ./train.py --cfg=$1 --output=$2 2>&1 | tee "$ROOT/logs/$(date +%Y%m%d-%H%M%S).log"
    mv -v $2.pb.gz $NETDIR
}

delay_count=$((MIN_GAP+1))
while true
do
    latest_num=0
    if [ -f "$LATESTFILE" ]
    then
        latest_num=$(cat $LATESTFILE)
    fi
    # Start a cycle once enough polls have passed AND either new games
    # arrived or the expected rescored chunk file exists.
    if [[ $delay_count -gt $MIN_GAP && ( $latest_num -gt $game_num || -f "$ROOT/data-rescored/$file" ) ]]
    then
        if [ $latest_num -gt $game_num ]
        then
            game_num=$latest_num
        fi
        echo ""
        # prepare ramdisk (under lock so the data set is not swapped mid-copy)
        (
            flock -e 200
            rsync -aq --delete-during $ROOT/split/{train,test} $RAMDISK
        ) 200>$LC0LOCKFILE
        # train all networks
        for netarch in ${NETARCHS[@]}
        do
            echo "Training $netarch:"
            train "$CONFIGDIR/$netarch.yaml" "${netarch}-$(date +"%Y_%m%d_%H%M_%S_%3N")"
        done
        # wait for next cycle
        echo $game_num > $GAMEFILE
        game_num=$((game_num + GAMES))
        file="training.${game_num}.gz"
        echo "Waiting for '$file'"
        delay_count=1
    else
        echo -n "."
        sleep 60
        delay_count=$((delay_count+1))
    fi
done
import numpy as np
import os
import random
import tensorflow as tf
import time
import bisect
import lc0_az_policy_map
import attention_policy_map as apm
import proto.net_pb2 as pb
from functools import reduce
import operator
from net import Net


def square_relu(x):
    """Squared-ReLU activation: relu(x) ** 2."""
    r = tf.nn.relu(x)
    return r * r


class Gating(tf.keras.layers.Layer):
    """Learned per-element gate, either additive or multiplicative."""

    def __init__(self, name=None, additive=True, init_value=None, **kwargs):
        self.additive = additive
        # Identity initialization: 0 for an additive gate, 1 for a
        # multiplicative one, unless the caller supplies a value.
        self.init_value = (0 if additive else 1) if init_value is None \
            else init_value
        super().__init__(name=name, **kwargs)

    def build(self, input_shape):
        # Multiplicative gates are kept non-negative; additive gates are free.
        constraint = None if self.additive else tf.keras.constraints.NonNeg()
        self.gate = self.add_weight(
            name='gate',
            shape=input_shape[1:],
            constraint=constraint,
            initializer=tf.constant_initializer(self.init_value),
            trainable=True)

    def call(self, inputs):
        if self.additive:
            return tf.add(inputs, self.gate)
        return tf.multiply(inputs, self.gate)


def ma_gating(inputs, name):
    """Apply a multiplicative gate followed by an additive gate."""
    gated = Gating(name=name + '/mult_gate', additive=False)(inputs)
    return Gating(name=name + '/add_gate', additive=True)(gated)


class ApplySqueezeExcitation(tf.keras.layers.Layer):
    """Squeeze-excitation application: inputs are [x, excited].

    'excited' is split into per-channel (gamma, beta) pairs and applied
    as sigmoid(gamma) * x + beta.
    """

    def __init__(self, **kwargs):
        super(ApplySqueezeExcitation, self).__init__(**kwargs)

    def build(self, input_dimens):
        # Channel count of the excitation branch (second input).
        self.reshape_size = input_dimens[1][1]

    def call(self, inputs):
        x, excited = inputs[0], inputs[1]
        gammas, betas = tf.split(
            tf.reshape(excited, [-1, self.reshape_size, 1, 1]), 2, axis=1)
        return tf.nn.sigmoid(gammas) * x + betas


class ApplyPolicyMap(tf.keras.layers.Layer):
    """Project the 80x8x8 convolutional policy output onto the move vector."""

    def __init__(self, **kwargs):
        super(ApplyPolicyMap, self).__init__(**kwargs)
        # Constant sparse mapping from conv planes to policy indices.
        self.fc1 = tf.constant(lc0_az_policy_map.make_map())

    def call(self, inputs):
        flat = tf.reshape(inputs, [-1, 80 * 8 * 8])
        return tf.matmul(flat, tf.cast(self.fc1, flat.dtype))
class Metric:
    """Running average of a scalar, mergeable across accumulators.

    Tracks a running sum ('value') together with a sample count; get()
    reports the mean, or the raw value (0.0) when nothing has been
    recorded yet.
    """

    def __init__(self, short_name, long_name, suffix='', **kwargs):
        self.short_name = short_name
        self.long_name = long_name
        self.suffix = suffix
        self.value = 0.0
        self.count = 0

    def assign(self, value):
        """Overwrite the accumulator with a single sample."""
        self.value = value
        self.count = 1

    def accumulate(self, value):
        """Add one sample to the running sum."""
        if not self.count:
            # First sample: behaves like a plain assignment.
            self.assign(value)
        else:
            self.value += value
            self.count += 1

    def merge(self, other):
        """Fold another accumulator for the same metric into this one."""
        assert self.short_name == other.short_name
        self.value += other.value
        self.count += other.count

    def get(self):
        """Return the mean of accumulated samples (raw value if count is 0)."""
        return self.value if self.count == 0 else self.value / self.count

    def reset(self):
        """Clear the accumulator back to its initial empty state."""
        self.value = 0.0
        self.count = 0
users who set both in yaml if self.cfg['model'].get('pol_encoder_layers') is not None: self.pol_encoder_layers = self.cfg['model'].get( 'pol_encoder_layers') assert not ((self.pol_encoder_layers > 0) and (self.encoder_layers > 0)), \ "Nets with both body encoder layers and policy encoder layers are not supported" self.pol_encoder_heads = self.cfg['model'].get('pol_encoder_heads', 2) self.pol_encoder_d_model = self.cfg['model'].get( 'pol_encoder_d_model', self.RESIDUAL_FILTERS) self.pol_encoder_dff = self.cfg['model'].get( 'pol_encoder_dff', (self.RESIDUAL_FILTERS * 1.5) // 1) self.policy_d_model = self.cfg['model'].get('policy_d_model', self.RESIDUAL_FILTERS) #encoder body self.input_gate = self.cfg['model'].get('input_gate') self.encoder_d_model = self.cfg['model'].get('encoder_d_model') self.encoder_dff = self.cfg['model'].get( 'encoder_dff', (self.RESIDUAL_FILTERS * 1.5) // 1) self.policy_d_model = self.cfg['model'].get('policy_d_model', self.RESIDUAL_FILTERS) self.arc_encoding = self.cfg['model'].get('arc_encoding', True) self.square_relu_ffn = self.cfg['model'].get('square_relu_ffn', False) self.use_smolgen = self.cfg['model'].get('use_smolgen', False) self.smolgen_hidden_channels = self.cfg['model'].get( 'smolgen_hidden_channels', 16) self.smolgen_hidden_sz = self.cfg['model'].get('smolgen_hidden_sz', 128) self.smolgen_gen_sz = self.cfg['model'].get('smolgen_gen_sz', 128) self.smolgen_activation = self.cfg['model'].get( 'smolgen_activation', 'swish') self.dropout_rate = self.cfg['model'].get('dropout_rate', 0.0) precision = self.cfg['training'].get('precision', 'single') loss_scale = self.cfg['training'].get('loss_scale', 128) self.virtual_batch_size = self.cfg['model'].get( 'virtual_batch_size', None) if precision == 'single': self.model_dtype = tf.float32 elif precision == 'half': self.model_dtype = tf.float16 else: raise ValueError("Unknown precision: {}".format(precision)) # Scale the loss to prevent gradient underflow self.loss_scale = 1 if 
self.model_dtype == tf.float32 else loss_scale policy_head = self.cfg['model'].get('policy', 'convolution') value_head = self.cfg['model'].get('value', 'wdl') moves_left_head = self.cfg['model'].get('moves_left', 'v1') input_mode = self.cfg['model'].get('input_type', 'classic') default_activation = self.cfg['model'].get('default_activation', 'relu') self.POLICY_HEAD = None self.VALUE_HEAD = None self.MOVES_LEFT_HEAD = None self.INPUT_MODE = None self.DEFAULT_ACTIVATION = None if policy_head == "classical": self.POLICY_HEAD = pb.NetworkFormat.POLICY_CLASSICAL elif policy_head == "convolution": self.POLICY_HEAD = pb.NetworkFormat.POLICY_CONVOLUTION elif policy_head == "attention": self.POLICY_HEAD = pb.NetworkFormat.POLICY_ATTENTION if self.pol_encoder_layers > 0: self.net.set_pol_headcount(self.pol_encoder_heads) else: raise ValueError( "Unknown policy head format: {}".format(policy_head)) self.net.set_policyformat(self.POLICY_HEAD) if value_head == "classical": self.VALUE_HEAD = pb.NetworkFormat.VALUE_CLASSICAL self.wdl = False elif value_head == "wdl": self.VALUE_HEAD = pb.NetworkFormat.VALUE_WDL self.wdl = True else: raise ValueError( "Unknown value head format: {}".format(value_head)) self.net.set_valueformat(self.VALUE_HEAD) if moves_left_head == "none": self.MOVES_LEFT_HEAD = pb.NetworkFormat.MOVES_LEFT_NONE self.moves_left = False elif moves_left_head == "v1": self.MOVES_LEFT_HEAD = pb.NetworkFormat.MOVES_LEFT_V1 self.moves_left = True else: raise ValueError( "Unknown moves left head format: {}".format(moves_left_head)) self.net.set_movesleftformat(self.MOVES_LEFT_HEAD) if input_mode == "classic": self.INPUT_MODE = pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE elif input_mode == "frc_castling": self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CASTLING_PLANE elif input_mode == "canonical": self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION elif input_mode == "canonical_100": self.INPUT_MODE = 
pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES elif input_mode == "canonical_armageddon": self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON elif input_mode == "canonical_v2": self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2 elif input_mode == "canonical_v2_armageddon": self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON else: raise ValueError( "Unknown input mode format: {}".format(input_mode)) self.net.set_input(self.INPUT_MODE) if default_activation == "relu": self.net.set_defaultactivation( pb.NetworkFormat.DEFAULT_ACTIVATION_RELU) self.DEFAULT_ACTIVATION = 'relu' elif default_activation == "mish": self.net.set_defaultactivation( pb.NetworkFormat.DEFAULT_ACTIVATION_MISH) try: self.DEFAULT_ACTIVATION = tf.keras.activations.mish except AttributeError: import tensorflow_addons as tfa self.DEFAULT_ACTIVATION = tfa.activations.mish else: raise ValueError("Unknown default activation type: {}".format( default_activation)) if self.encoder_layers > 0: self.net.set_headcount(self.encoder_heads) self.net.set_networkformat( pb.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT) self.net.set_smolgen_activation( self.net.activation(self.smolgen_activation)) self.net.set_ffn_activation( self.net.activation( 'sqrrelu' if self.square_relu_ffn else 'default')) self.swa_enabled = self.cfg['training'].get('swa', False) # Limit momentum of SWA exponential average to 1 - 1/(swa_max_n + 1) self.swa_max_n = self.cfg['training'].get('swa_max_n', 0) self.renorm_enabled = self.cfg['training'].get('renorm', False) self.renorm_max_r = self.cfg['training'].get('renorm_max_r', 1) self.renorm_max_d = self.cfg['training'].get('renorm_max_d', 0) self.renorm_momentum = self.cfg['training'].get( 'renorm_momentum', 0.99) if self.cfg['gpu'] == 'all': gpus = tf.config.experimental.list_physical_devices('GPU') for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) self.strategy = 
tf.distribute.MirroredStrategy() tf.distribute.experimental_set_strategy(self.strategy) else: gpus = tf.config.experimental.list_physical_devices('GPU') print(gpus) tf.config.experimental.set_visible_devices(gpus[self.cfg['gpu']], 'GPU') tf.config.experimental.set_memory_growth(gpus[self.cfg['gpu']], True) self.strategy = None if self.model_dtype == tf.float16: tf.keras.mixed_precision.experimental.set_policy('mixed_float16') self.global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int64) def init(self, train_dataset, test_dataset, validation_dataset=None): if self.strategy is not None: self.train_dataset = self.strategy.experimental_distribute_dataset( train_dataset) else: self.train_dataset = train_dataset self.train_iter = iter(self.train_dataset) if self.strategy is not None: self.test_dataset = self.strategy.experimental_distribute_dataset( test_dataset) else: self.test_dataset = test_dataset self.test_iter = iter(self.test_dataset) if self.strategy is not None and validation_dataset is not None: self.validation_dataset = self.strategy.experimental_distribute_dataset( validation_dataset) else: self.validation_dataset = validation_dataset if self.strategy is not None: this = self with self.strategy.scope(): this.init_net() else: self.init_net() def init_net(self): self.l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001)) input_var = tf.keras.Input(shape=(112, 8, 8)) outputs = self.construct_net(input_var) self.model = tf.keras.Model(inputs=input_var, outputs=outputs) # swa_count initialized regardless to make checkpoint code simpler. self.swa_count = tf.Variable(0., name='swa_count', trainable=False) self.swa_weights = None if self.swa_enabled: # Count of networks accumulated into SWA self.swa_weights = [ tf.Variable(w, trainable=False) for w in self.model.weights ] self.active_lr = tf.Variable(0.01, trainable=False) # All 'new' (TF 2.10 or newer non-legacy) optimizers must have learning_rate updated manually. 
        self.update_lr_manually = False
        # Be sure not to set new_optimizer before TF 2.11, or unless you edit the code to specify a new optimizer explicitly.
        if self.cfg['training'].get('new_optimizer'):
            # Non-legacy optimizer: learning_rate cannot be a callable, so it
            # is pushed manually each step (see train_step).
            self.optimizer = tf.keras.optimizers.SGD(
                learning_rate=self.active_lr, momentum=0.9, nesterov=True)
            self.update_lr_manually = True
        else:
            # Legacy optimizer accepts a callable lr; fall back to the plain
            # SGD class on TF versions without tf.keras.optimizers.legacy.
            try:
                self.optimizer = tf.keras.optimizers.legacy.SGD(
                    learning_rate=lambda: self.active_lr,
                    momentum=0.9,
                    nesterov=True)
            except AttributeError:
                self.optimizer = tf.keras.optimizers.SGD(
                    learning_rate=lambda: self.active_lr,
                    momentum=0.9,
                    nesterov=True)
        # Keep a handle on the unwrapped optimizer; the checkpoint and manual
        # lr updates go through it even when wrapped below.
        self.orig_optimizer = self.optimizer
        # Gradient aggregation hook; attribute name differs across TF versions.
        try:
            self.aggregator = self.orig_optimizer.aggregate_gradients
        except AttributeError:
            self.aggregator = self.orig_optimizer.gradient_aggregator
        if self.loss_scale != 1:
            # NOTE(review): tf.keras.mixed_precision.experimental is a
            # deprecated API surface in newer TF releases — confirm the
            # targeted TF version still ships it.
            self.optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self.optimizer, self.loss_scale)
        if self.cfg['training'].get('lookahead_optimizer'):
            import tensorflow_addons as tfa
            self.optimizer = tfa.optimizers.Lookahead(self.optimizer)

        def correct_policy(target, output):
            # Prepare (target, output) for policy-head losses: optionally mask
            # illegal moves out of the logits and clamp the -1 markers in the
            # target to 0.
            output = tf.cast(output, tf.float32)
            # Calculate loss on policy head
            if self.cfg['training'].get('mask_legal_moves'):
                # extract mask for legal moves from target policy
                move_is_legal = tf.greater_equal(target, 0)
                # replace logits of illegal moves with large negative value (so that it doesn't affect policy of legal moves) without gradient
                illegal_filler = tf.zeros_like(output) - 1.0e10
                output = tf.where(move_is_legal, output, illegal_filler)
            # y_ still has -1 on illegal moves, flush them to 0
            target = tf.nn.relu(target)
            return target, output

        def policy_loss(target, output):
            # Cross-entropy between the (masked) target policy and the logits.
            target, output = correct_policy(target, output)
            policy_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                labels=tf.stop_gradient(target), logits=output)
            return tf.reduce_mean(input_tensor=policy_cross_entropy)

        self.policy_loss_fn = policy_loss

        def policy_accuracy(target, output):
            # Fraction of positions where the argmax move matches the target.
            target, output = correct_policy(target, output)
            return tf.reduce_mean(
                tf.cast(
                    tf.equal(tf.argmax(input=target, axis=1),
                             tf.argmax(input=output, axis=1)), tf.float32))

        self.policy_accuracy_fn = policy_accuracy

        def moves_left_mean_error_fn(target, output):
            # Mean absolute error of the moves-left head.
            output = tf.cast(output, tf.float32)
            return tf.reduce_mean(tf.abs(target - output))

        self.moves_left_mean_error = moves_left_mean_error_fn

        def policy_entropy(target, output):
            # Entropy of the predicted (softmaxed) policy distribution.
            target, output = correct_policy(target, output)
            softmaxed = tf.nn.softmax(output)
            return tf.math.negative(
                tf.reduce_mean(
                    tf.reduce_sum(tf.math.xlogy(softmaxed, softmaxed),
                                  axis=1)))

        self.policy_entropy_fn = policy_entropy

        def policy_uniform_loss(target, output):
            # Cross-entropy against a uniform distribution over legal moves
            # (legal = target entry >= 0).
            uniform = tf.where(tf.greater_equal(target, 0),
                               tf.ones_like(target), tf.zeros_like(target))
            balanced_uniform = uniform / tf.reduce_sum(
                uniform, axis=1, keepdims=True)
            target, output = correct_policy(target, output)
            policy_cross_entropy = \
                tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(balanced_uniform),
                                                        logits=output)
            return tf.reduce_mean(input_tensor=policy_cross_entropy)

        self.policy_uniform_loss_fn = policy_uniform_loss

        # q_ratio blends the search q value into the game result z (0 = pure z).
        q_ratio = self.cfg['training'].get('q_ratio', 0)
        assert 0 <= q_ratio <= 1

        # Linear conversion to scalar to compute MSE with, for comparison to old values
        wdl = tf.expand_dims(tf.constant([1.0, 0.0, -1.0]), 1)

        self.qMix = lambda z, q: q * q_ratio + z * (1 - q_ratio)
        # Loss on value head
        if self.wdl:
            def value_loss(target, output):
                # Cross-entropy over the win/draw/loss distribution.
                output = tf.cast(output, tf.float32)
                value_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                    labels=tf.stop_gradient(target), logits=output)
                return tf.reduce_mean(input_tensor=value_cross_entropy)

            self.value_loss_fn = value_loss

            def mse_loss(target, output):
                # Collapse both WDL distributions to scalars and compare.
                output = tf.cast(output, tf.float32)
                scalar_z_conv = tf.matmul(tf.nn.softmax(output), wdl)
                scalar_target = tf.matmul(target, wdl)
                return tf.reduce_mean(input_tensor=tf.math.squared_difference(
                    scalar_target, scalar_z_conv))

            self.mse_loss_fn = mse_loss
        else:
            def value_loss(target, output):
                # No CE value loss when the head outputs a scalar.
                return tf.constant(0)

            self.value_loss_fn = value_loss

            def mse_loss(target, output):
                # Scalar value head: MSE against the scalar-converted target.
                output = tf.cast(output, tf.float32)
                scalar_target = tf.matmul(target, wdl)
                return tf.reduce_mean(input_tensor=tf.math.squared_difference(
                    scalar_target, output))

            self.mse_loss_fn = mse_loss

        if self.moves_left:
            def moves_left_loss(target, output):
                # Scale the loss to similar range as other losses.
                scale = 20.0
                target = target / scale
                output = tf.cast(output, tf.float32) / scale
                if self.strategy is not None:
                    # Under a distribution strategy the per-replica reduction
                    # is done explicitly, so disable Keras' own reduction.
                    huber = tf.keras.losses.Huber(
                        10.0 / scale, reduction=tf.keras.losses.Reduction.NONE)
                else:
                    huber = tf.keras.losses.Huber(10.0 / scale)
                return tf.reduce_mean(huber(target, output))
        else:
            moves_left_loss = None

        self.moves_left_loss_fn = moves_left_loss

        pol_loss_w = self.cfg['training']['policy_loss_weight']
        val_loss_w = self.cfg['training']['value_loss_weight']

        if self.moves_left:
            moves_loss_w = self.cfg['training']['moves_left_loss_weight']
        else:
            moves_loss_w = tf.constant(0.0, dtype=tf.float32)
        reg_term_w = self.cfg['training'].get('reg_term_weight', 1.0)

        def _lossMix(policy, value, moves_left, reg_term):
            # Weighted sum of the individual losses into the training loss.
            return pol_loss_w * policy + val_loss_w * value + moves_loss_w * moves_left + reg_term_w * reg_term

        self.lossMix = _lossMix

        def accuracy(target, output):
            # Argmax agreement, used for value accuracy on WDL heads.
            output = tf.cast(output, tf.float32)
            return tf.reduce_mean(
                tf.cast(
                    tf.equal(tf.argmax(input=target, axis=1),
                             tf.argmax(input=output, axis=1)), tf.float32))

        self.accuracy_fn = accuracy

        # Order must match the order in process_inner_loop
        self.train_metrics = [
            Metric('P', 'Policy Loss'),
            Metric('V', 'Value Loss'),
            Metric('ML', 'Moves Left Loss'),
            Metric('Reg', 'Reg term'),
            Metric('Total', 'Total Loss'),
            Metric(
                'V MSE', 'MSE Loss'
            ),  # Long name here doesn't mention value for backwards compatibility reasons.
        ]
        self.time_start = None
        self.last_steps = None

        # Order must match the order in calculate_test_summaries_inner_loop
        self.test_metrics = [
            Metric('P', 'Policy Loss'),
            Metric('V', 'Value Loss'),
            Metric('ML', 'Moves Left Loss'),
            Metric(
                'V MSE', 'MSE Loss'
            ),  # Long name here doesn't mention value for backwards compatibility reasons.
            Metric('P Acc', 'Policy Accuracy', suffix='%'),
            Metric('V Acc', 'Value Accuracy', suffix='%'),
            Metric('ML Mean', 'Moves Left Mean Error'),
            Metric('P Entropy', 'Policy Entropy'),
            Metric('P UL', 'Policy UL'),
        ]

        # Set adaptive learning rate during training
        self.cfg['training']['lr_boundaries'].sort()
        self.warmup_steps = self.cfg['training'].get('warmup_steps', 0)
        self.lr = self.cfg['training']['lr_values'][0]

        # TensorBoard writers; one log dir per run flavour under ./leelalogs.
        self.test_writer = tf.summary.create_file_writer(
            os.path.join(os.getcwd(),
                         "leelalogs/{}-test".format(self.cfg['name'])))
        self.train_writer = tf.summary.create_file_writer(
            os.path.join(os.getcwd(),
                         "leelalogs/{}-train".format(self.cfg['name'])))
        if vars(self).get('validation_dataset', None) is not None:
            self.validation_writer = tf.summary.create_file_writer(
                os.path.join(
                    os.getcwd(),
                    "leelalogs/{}-validation".format(self.cfg['name'])))
        if self.swa_enabled:
            self.swa_writer = tf.summary.create_file_writer(
                os.path.join(os.getcwd(),
                             "leelalogs/{}-swa-test".format(self.cfg['name'])))
            self.swa_validation_writer = tf.summary.create_file_writer(
                os.path.join(
                    os.getcwd(),
                    "leelalogs/{}-swa-validation".format(self.cfg['name'])))
        # Checkpoint tracks the unwrapped optimizer plus model/step/SWA state.
        self.checkpoint = tf.train.Checkpoint(optimizer=self.orig_optimizer,
                                              model=self.model,
                                              global_step=self.global_step,
                                              swa_count=self.swa_count)
        self.checkpoint.listed = self.swa_weights
        self.manager = tf.train.CheckpointManager(
            self.checkpoint,
            directory=self.root_dir,
            max_to_keep=50,
            keep_checkpoint_every_n_hours=24,
            checkpoint_name=self.cfg['name'])

    def replace_weights(self, proto_filename, ignore_errors=False):
        """Load weights from a Leela protobuf file into the Keras model.

        With ignore_errors=False, mismatches in filters/blocks/head formats
        or tensor sizes raise; with ignore_errors=True they are printed and
        skipped.
        """
        self.net.parse_proto(proto_filename)

        filters, blocks = self.net.filters(), self.net.blocks()
        if not ignore_errors:
            if self.RESIDUAL_FILTERS != filters:
                raise ValueError("Number of filters doesn't match the network")
            if self.RESIDUAL_BLOCKS != blocks:
                raise ValueError("Number of blocks doesn't match the network")
            if self.POLICY_HEAD != self.net.pb.format.network_format.policy:
                raise ValueError("Policy head type doesn't match the network")
            if self.VALUE_HEAD != self.net.pb.format.network_format.value:
                raise ValueError("Value head type doesn't match the network")

        # List all tensor names we need weights for.
        names = []
        for weight in self.model.weights:
            names.append(weight.name)

        new_weights = self.net.get_weights_v2(names)
        for weight in self.model.weights:
            if 'renorm' in weight.name:
                # Renorm variables are not populated.
                continue

            try:
                new_weight = new_weights[weight.name]
            except KeyError:
                error_string = 'No values for tensor {} in protobuf'.format(
                    weight.name)
                if ignore_errors:
                    print(error_string)
                    continue
                else:
                    raise KeyError(error_string)

            if reduce(operator.mul, weight.shape.as_list(),
                      1) != len(new_weight):
                error_string = 'Tensor {} has wrong length. Tensorflow shape {}, size in protobuf {}'.format(
                    weight.name, weight.shape.as_list(), len(new_weight))
                if ignore_errors:
                    print(error_string)
                    continue
                else:
                    raise KeyError(error_string)

            if weight.shape.ndims == 4:
                # Rescale rule50 related weights as clients do not normalize the input.
                if weight.name == 'input/conv2d/kernel:0' and self.net.pb.format.network_format.input < pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES:
                    num_inputs = 112
                    # 50 move rule is the 110th input, or 109 starting from 0.
rule50_input = 109 for i in range(len(new_weight)): if (i % (num_inputs * 9)) // 9 == rule50_input: new_weight[i] = new_weight[i] * 99 # Convolution weights need a transpose # # TF (kYXInputOutput) # [filter_height, filter_width, in_channels, out_channels] # # Leela/cuDNN/Caffe (kOutputInputYX) # [output, input, filter_size, filter_size] s = weight.shape.as_list() shape = [s[i] for i in [3, 2, 0, 1]] new_weight = tf.constant(new_weight, shape=shape) weight.assign(tf.transpose(a=new_weight, perm=[2, 3, 1, 0])) elif weight.shape.ndims == 2: # Fully connected layers are [in, out] in TF # # [out, in] in Leela # s = weight.shape.as_list() shape = [s[i] for i in [1, 0]] new_weight = tf.constant(new_weight, shape=shape) weight.assign(tf.transpose(a=new_weight, perm=[1, 0])) else: # Biases, batchnorm etc new_weight = tf.constant(new_weight, shape=weight.shape) weight.assign(new_weight) # Replace the SWA weights as well, ensuring swa accumulation is reset. if self.swa_enabled: self.swa_count.assign(tf.constant(0.)) self.update_swa() # This should result in identical file to the starting one # self.save_leelaz_weights('restored.pb.gz') def restore(self): if self.manager.latest_checkpoint is not None: print("Restoring from {0}".format(self.manager.latest_checkpoint)) self.checkpoint.restore(self.manager.latest_checkpoint) def process_loop(self, batch_size, test_batches, batch_splits=1): if self.swa_enabled: # split half of test_batches between testing regular weights and SWA weights test_batches //= 2 # Make sure that ghost batch norm can be applied if self.virtual_batch_size and batch_size % self.virtual_batch_size != 0: # Adjust required batch size for batch splitting. required_factor = self.virtual_batch_size * self.cfg[ 'training'].get('num_batch_splits', 1) raise ValueError( 'batch_size must be a multiple of {}'.format(required_factor)) # Get the initial steps value in case this is a resume from a step count # which is not a multiple of total_steps. 
        steps = self.global_step.read_value()
        self.last_steps = steps
        self.time_start = time.time()
        self.profiling_start_step = None
        total_steps = self.cfg['training']['total_steps']
        for _ in range(steps % total_steps, total_steps):
            self.process(batch_size, test_batches, batch_splits=batch_splits)

    @tf.function()
    def read_weights(self):
        """Snapshot all model weight values as tensors."""
        return [w.read_value() for w in self.model.weights]

    @tf.function()
    def process_inner_loop(self, x, y, z, q, m):
        """Forward/backward pass for one (sub-)batch.

        Returns (metrics, gradients) where metrics order must match
        self.train_metrics: policy, value, moves-left, reg, total, mse/4.
        """
        with tf.GradientTape() as tape:
            outputs = self.model(x, training=True)
            policy = outputs[0]
            value = outputs[1]
            policy_loss = self.policy_loss_fn(y, policy)
            reg_term = sum(self.model.losses)
            if self.wdl:
                value_ce_loss = self.value_loss_fn(self.qMix(z, q), value)
                value_loss = value_ce_loss
            else:
                value_mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
                value_loss = value_mse_loss
            if self.moves_left:
                moves_left = outputs[2]
                moves_left_loss = self.moves_left_loss_fn(m, moves_left)
            else:
                moves_left_loss = tf.constant(0.)
            total_loss = self.lossMix(policy_loss, value_loss,
                                      moves_left_loss, reg_term)
            if self.loss_scale != 1:
                total_loss = self.optimizer.get_scaled_loss(total_loss)
        # NOTE(review): in the non-wdl branch below, `mse_loss` is never
        # assigned before being used in `metrics`, and `value_loss` (already
        # the MSE value above) is overwritten with the constant-0
        # value_loss_fn. Looks like a latent NameError when wdl is disabled —
        # confirm whether the non-wdl path is ever exercised.
        if self.wdl:
            mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
        else:
            value_loss = self.value_loss_fn(self.qMix(z, q), value)
        metrics = [
            policy_loss,
            value_loss,
            moves_left_loss,
            reg_term,
            total_loss,
            # Google's paper scales MSE by 1/4 to a [0, 1] range, so do the same to
            # get comparable values.
            mse_loss / 4.0,
        ]
        return metrics, tape.gradient(total_loss,
                                      self.model.trainable_weights)

    @tf.function()
    def strategy_process_inner_loop(self, x, y, z, q, m):
        """Distributed wrapper: run process_inner_loop per replica and
        mean-reduce the metrics (gradients stay per-replica)."""
        metrics, new_grads = self.strategy.run(self.process_inner_loop,
                                               args=(x, y, z, q, m))
        metrics = [
            self.strategy.reduce(tf.distribute.ReduceOp.MEAN, m, axis=None)
            for m in metrics
        ]
        return metrics, new_grads

    def apply_grads(self, grads, effective_batch_splits):
        """Aggregate, unscale, clip and apply gradients; return global norm."""
        grads = [
            g[0]
            for g in self.aggregator(zip(grads, self.model.trainable_weights))
        ]
        if self.loss_scale != 1:
            grads = self.optimizer.get_unscaled_gradients(grads)
        # Gradients are summed over batch splits, so the clip threshold is
        # scaled up by the same factor.
        max_grad_norm = self.cfg['training'].get(
            'max_grad_norm', 10000.0) * effective_batch_splits
        grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
        self.optimizer.apply_gradients(zip(grads,
                                           self.model.trainable_weights),
                                       experimental_aggregate_gradients=False)
        return grad_norm

    @tf.function()
    def strategy_apply_grads(self, grads, effective_batch_splits):
        """Distributed wrapper around apply_grads; mean-reduces the norm."""
        grad_norm = self.strategy.run(self.apply_grads,
                                      args=(grads, effective_batch_splits))
        grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                         grad_norm,
                                         axis=None)
        return grad_norm

    @tf.function()
    def merge_grads(self, grads, new_grads):
        """Elementwise sum of two gradient lists (batch-split accumulation)."""
        return [tf.math.add(a, b) for (a, b) in zip(grads, new_grads)]

    @tf.function()
    def strategy_merge_grads(self, grads, new_grads):
        """Distributed wrapper around merge_grads."""
        return self.strategy.run(self.merge_grads, args=(grads, new_grads))

    def train_step(self, steps, batch_size, batch_splits):
        """Run one optimizer step accumulated over batch_splits sub-batches.

        Also logs averaged train metrics and the update ratio on reporting
        steps. Returns the incremented global step value.
        """
        # need to add 1 to steps because steps will be incremented after gradient update
        if (steps +
                1) % self.cfg['training']['train_avg_report_steps'] == 0 or (
                    steps + 1) % self.cfg['training']['total_steps'] == 0:
            before_weights = self.read_weights()

        # Run training for this batch
        grads = None
        for _ in range(batch_splits):
            x, y, z, q, m = next(self.train_iter)
            if self.strategy is not None:
                metrics, new_grads = self.strategy_process_inner_loop(
                    x, y, z, q, m)
            else:
                metrics, new_grads = self.process_inner_loop(x, y, z, q, m)
            if not grads:
                grads = new_grads
            else:
                if self.strategy is not None:
                    grads = self.strategy_merge_grads(grads, new_grads)
                else:
                    grads = self.merge_grads(grads, new_grads)
            # Keep running averages
            for acc, val in zip(self.train_metrics, metrics):
                acc.accumulate(val)

        # Gradients of batch splits are summed, not averaged like usual, so need to scale lr accordingly to correct for this.
        effective_batch_splits = batch_splits
        if self.strategy is not None:
            effective_batch_splits = batch_splits * self.strategy.num_replicas_in_sync
        self.active_lr.assign(self.lr / effective_batch_splits)
        if self.update_lr_manually:
            # New-style optimizers do not re-read a callable lr; push it.
            self.orig_optimizer.learning_rate = self.active_lr
        if self.strategy is not None:
            grad_norm = self.strategy_apply_grads(grads,
                                                  effective_batch_splits)
        else:
            grad_norm = self.apply_grads(grads, effective_batch_splits)

        # Note: grads variable at this point has not been unscaled or
        # had clipping applied. Since no code after this point depends
        # upon that it seems fine for now.

        # Update steps.
        self.global_step.assign_add(1)
        steps = self.global_step.read_value()

        if steps % self.cfg['training'][
                'train_avg_report_steps'] == 0 or steps % self.cfg['training'][
                    'total_steps'] == 0:
            time_end = time.time()
            speed = 0
            if self.time_start:
                elapsed = time_end - self.time_start
                steps_elapsed = steps - self.last_steps
                speed = batch_size * (tf.cast(steps_elapsed, tf.float32) /
                                      elapsed)
            print("step {}, lr={:g}".format(steps, self.lr), end='')
            for metric in self.train_metrics:
                print(" {}={:g}{}".format(metric.short_name, metric.get(),
                                          metric.suffix),
                      end='')
            print(" ({:g} pos/s)".format(speed))

            after_weights = self.read_weights()
            with self.train_writer.as_default():
                for metric in self.train_metrics:
                    tf.summary.scalar(metric.long_name,
                                      metric.get(),
                                      step=steps)
                tf.summary.scalar("LR", self.lr, step=steps)
                tf.summary.scalar("Gradient norm",
                                  grad_norm / effective_batch_splits,
                                  step=steps)
                self.compute_update_ratio(before_weights, after_weights,
                                          steps)
            self.train_writer.flush()

            self.time_start = time_end
            self.last_steps = steps
            for metric in self.train_metrics:
                metric.reset()
        return steps

    def process(self, batch_size, test_batches, batch_splits):
        """One outer iteration: optional profiling/testing, lr schedule,
        a train step, and SWA accumulation."""
        # Get the initial steps value before we do a training step.
        steps = self.global_step.read_value()

        # By default disabled since 0 != 10.
        if steps % self.cfg['training'].get(
                'profile_step_freq',
                1) == self.cfg['training'].get('profile_step_offset', 10):
            self.profiling_start_step = steps
            tf.profiler.experimental.start(
                os.path.join(os.getcwd(),
                             "leelalogs/{}-profile".format(self.cfg['name'])))

        # Run test before first step to see delta since end of last run.
        if steps % self.cfg['training']['total_steps'] == 0:
            with tf.profiler.experimental.Trace("Test", step_num=steps + 1):
                # Steps is given as one higher than current in order to avoid it
                # being equal to the value the end of a run is stored against.
                self.calculate_test_summaries(test_batches, steps + 1)
                if self.swa_enabled:
                    self.calculate_swa_summaries(test_batches, steps + 1)

        # Determine learning rate
        lr_values = self.cfg['training']['lr_values']
        lr_boundaries = self.cfg['training']['lr_boundaries']
        steps_total = steps % self.cfg['training']['total_steps']
        self.lr = lr_values[bisect.bisect_right(lr_boundaries, steps_total)]
        if self.warmup_steps > 0 and steps < self.warmup_steps:
            # Linear warmup from lr/warmup_steps up to lr.
            self.lr = self.lr * tf.cast(steps + 1,
                                        tf.float32) / self.warmup_steps

        with tf.profiler.experimental.Trace("Train", step_num=steps):
            steps = self.train_step(steps, batch_size, batch_splits)

        if self.swa_enabled and steps % self.cfg['training']['swa_steps'] == 0:
            self.update_swa()

        # Calculate test values every 'test_steps', but also ensure there is
        # one at the final step so the delta to the first step can be calculated.
        if steps % self.cfg['training']['test_steps'] == 0 or steps % self.cfg[
                'training']['total_steps'] == 0:
            with tf.profiler.experimental.Trace("Test", step_num=steps):
                self.calculate_test_summaries(test_batches, steps)
                if self.swa_enabled:
                    self.calculate_swa_summaries(test_batches, steps)

        if self.validation_dataset is not None and (
                steps % self.cfg['training']['validation_steps'] == 0
                or steps % self.cfg['training']['total_steps'] == 0):
            with tf.profiler.experimental.Trace("Validate", step_num=steps):
                if self.swa_enabled:
                    self.calculate_swa_validations(steps)
                else:
                    self.calculate_test_validations(steps)

        # Save session and weights at end, and also optionally every 'checkpoint_steps'.
        if steps % self.cfg['training']['total_steps'] == 0 or (
                'checkpoint_steps' in self.cfg['training']
                and steps % self.cfg['training']['checkpoint_steps'] == 0):
            evaled_steps = steps.numpy()
            self.manager.save(checkpoint_number=evaled_steps)
            print("Model saved in file: {}".format(
                self.manager.latest_checkpoint))
            path = os.path.join(self.root_dir, self.cfg['name'])
            leela_path = path + "-" + str(evaled_steps)
            swa_path = path + "-swa-" + str(evaled_steps)
            self.net.pb.training_params.training_steps = evaled_steps
            self.save_leelaz_weights(leela_path)
            if self.swa_enabled:
                self.save_swa_weights(swa_path)

        if self.profiling_start_step is not None and (
                steps >= self.profiling_start_step +
                self.cfg['training'].get('profile_step_count', 0)
                or steps % self.cfg['training']['total_steps'] == 0):
            tf.profiler.experimental.stop()
            self.profiling_start_step = None

    def calculate_swa_summaries(self, test_batches, steps):
        """Run test summaries with SWA weights temporarily swapped in."""
        backup = self.read_weights()
        for (swa, w) in zip(self.swa_weights, self.model.weights):
            w.assign(swa.read_value())
        true_test_writer, self.test_writer = self.test_writer, self.swa_writer
        print('swa', end=' ')
        self.calculate_test_summaries(test_batches, steps)
        self.test_writer = true_test_writer
        for (old, w) in zip(backup, self.model.weights):
            w.assign(old)

    @tf.function()
    def calculate_test_summaries_inner_loop(self, x, y, z, q, m):
        """Evaluate one test batch; metric order must match self.test_metrics."""
        outputs = self.model(x, training=False)
        policy = outputs[0]
        value = outputs[1]
        policy_loss = self.policy_loss_fn(y, policy)
        policy_accuracy = self.policy_accuracy_fn(y, policy)
        policy_entropy = self.policy_entropy_fn(y, policy)
        policy_ul = self.policy_uniform_loss_fn(y, policy)
        if self.wdl:
            value_loss = self.value_loss_fn(self.qMix(z, q), value)
            mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
            value_accuracy = self.accuracy_fn(self.qMix(z, q), value)
        else:
            value_loss = self.value_loss_fn(self.qMix(z, q), value)
            mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
            value_accuracy = tf.constant(0.)
        if self.moves_left:
            moves_left = outputs[2]
            moves_left_loss = self.moves_left_loss_fn(m, moves_left)
            moves_left_mean_error = self.moves_left_mean_error(m, moves_left)
        else:
            moves_left_loss = tf.constant(0.)
            moves_left_mean_error = tf.constant(0.)
        metrics = [
            policy_loss,
            value_loss,
            moves_left_loss,
            mse_loss / 4,
            policy_accuracy * 100,
            value_accuracy * 100,
            moves_left_mean_error,
            policy_entropy,
            policy_ul,
        ]
        return metrics

    @tf.function()
    def strategy_calculate_test_summaries_inner_loop(self, x, y, z, q, m):
        """Distributed wrapper; mean-reduces per-replica test metrics."""
        metrics = self.strategy.run(self.calculate_test_summaries_inner_loop,
                                    args=(x, y, z, q, m))
        metrics = [
            self.strategy.reduce(tf.distribute.ReduceOp.MEAN, m, axis=None)
            for m in metrics
        ]
        return metrics

    def calculate_test_summaries(self, test_batches, steps):
        """Evaluate test_batches batches, log to TensorBoard and stdout,
        and record headline losses into the network protobuf."""
        for metric in self.test_metrics:
            metric.reset()
        for _ in range(0, test_batches):
            x, y, z, q, m = next(self.test_iter)
            if self.strategy is not None:
                metrics = self.strategy_calculate_test_summaries_inner_loop(
                    x, y, z, q, m)
            else:
                metrics = self.calculate_test_summaries_inner_loop(
                    x, y, z, q, m)
            for acc, val in zip(self.test_metrics, metrics):
                acc.accumulate(val)
        self.net.pb.training_params.learning_rate = self.lr
        self.net.pb.training_params.mse_loss = self.test_metrics[3].get()
        self.net.pb.training_params.policy_loss = self.test_metrics[0].get()
        # TODO store value and value accuracy in pb
        self.net.pb.training_params.accuracy = self.test_metrics[4].get()

        with self.test_writer.as_default():
            for metric in self.test_metrics:
                tf.summary.scalar(metric.long_name, metric.get(), step=steps)
            for w in self.model.weights:
                tf.summary.histogram(w.name, w, step=steps)
        self.test_writer.flush()

        print("step {},".format(steps), end='')
        for metric in self.test_metrics:
            print(" {}={:g}{}".format(metric.short_name, metric.get(),
                                      metric.suffix),
                  end='')
        print()

    def calculate_swa_validations(self, steps):
        """Run validation with SWA weights temporarily swapped in."""
        backup = self.read_weights()
        for (swa, w) in zip(self.swa_weights, self.model.weights):
            w.assign(swa.read_value())
        true_validation_writer, self.validation_writer = self.validation_writer, self.swa_validation_writer
        print('swa', end=' ')
        self.calculate_test_validations(steps)
        self.validation_writer = true_validation_writer
        for (old, w) in zip(backup, self.model.weights):
            w.assign(old)

    def calculate_test_validations(self, steps):
        """Evaluate the full validation dataset and log the test metrics."""
        for metric in self.test_metrics:
            metric.reset()
        for (x, y, z, q, m) in self.validation_dataset:
            if self.strategy is not None:
                metrics = self.strategy_calculate_test_summaries_inner_loop(
                    x, y, z, q, m)
            else:
                metrics = self.calculate_test_summaries_inner_loop(
                    x, y, z, q, m)
            for acc, val in zip(self.test_metrics, metrics):
                acc.accumulate(val)

        with self.validation_writer.as_default():
            for metric in self.test_metrics:
                tf.summary.scalar(metric.long_name, metric.get(), step=steps)
        self.validation_writer.flush()

        print("step {}, validation:".format(steps), end='')
        for metric in self.test_metrics:
            print(" {}={:g}{}".format(metric.short_name, metric.get(),
                                      metric.suffix),
                  end='')
        print()

    @tf.function()
    def compute_update_ratio(self, before_weights, after_weights, steps):
        """Compute the ratio of gradient norm to weight norm.

        Adapted from https://github.com/tensorflow/minigo/blob/c923cd5b11f7d417c9541ad61414bf175a84dc31/dual_net.py#L567
        """
        deltas = [
            after - before
            for after, before in zip(after_weights, before_weights)
        ]
        delta_norms = [tf.math.reduce_euclidean_norm(d) for d in deltas]
        weight_norms = [
            tf.math.reduce_euclidean_norm(w) for w in before_weights
        ]
        # Skip batch-norm moving statistics; guard against zero weight norm.
        ratios = [(tensor.name, tf.cond(w != 0., lambda: d / w, lambda: -1.))
                  for d, w, tensor in zip(delta_norms, weight_norms,
                                          self.model.weights)
                  if not 'moving' in tensor.name]
        for name, ratio in ratios:
            tf.summary.scalar('update_ratios/' + name, ratio, step=steps)
        # Filtering is hard, so just push infinities/NaNs to an unreasonably large value.
        # 2.30258509299 is ln(10): log(r)/ln(10) = log10(r).
        ratios = [
            tf.cond(r > 0, lambda: tf.math.log(r) / 2.30258509299,
                    lambda: 200.) for (_, r) in ratios
        ]
        tf.summary.histogram('update_ratios_log10',
                             tf.stack(ratios),
                             buckets=1000,
                             step=steps)

    def update_swa(self):
        """Fold the current weights into the SWA running average.

        The count saturates at swa_max_n, which bounds the averaging
        momentum at 1 - 1/(swa_max_n + 1).
        """
        num = self.swa_count.read_value()
        for (w, swa) in zip(self.model.weights, self.swa_weights):
            swa.assign(swa.read_value() * (num / (num + 1.)) +
                       w.read_value() * (1. / (num + 1.)))
        self.swa_count.assign(min(num + 1., self.swa_max_n))

    def save_swa_weights(self, filename):
        """Write a Leela weights file using the SWA weights."""
        backup = self.read_weights()
        for (swa, w) in zip(self.swa_weights, self.model.weights):
            w.assign(swa.read_value())
        self.save_leelaz_weights(filename)
        for (old, w) in zip(backup, self.model.weights):
            w.assign(old)

    def save_leelaz_weights(self, filename):
        """Serialize the current model weights into a Leela proto file."""
        numpy_weights = []
        for weight in self.model.weights:
            numpy_weights.append([weight.name, weight.numpy()])
        self.net.fill_net_v2(numpy_weights)
        self.net.save_proto(filename)

    def batch_norm(self, input, name, scale=False):
        """Channels-first BatchNormalization, optionally with renorm or
        ghost (virtual) batching, per the training config."""
        if self.renorm_enabled:
            clipping = {
                "rmin": 1.0 / self.renorm_max_r,
                "rmax": self.renorm_max_r,
                "dmax": self.renorm_max_d
            }
            return tf.keras.layers.BatchNormalization(
                epsilon=1e-5,
                axis=1,
                fused=False,
                center=True,
                scale=scale,
                renorm=True,
                renorm_clipping=clipping,
                renorm_momentum=self.renorm_momentum,
                name=name)(input)
        else:
            return tf.keras.layers.BatchNormalization(
                epsilon=1e-5,
                axis=1,
                center=True,
                scale=scale,
                virtual_batch_size=self.virtual_batch_size,
                name=name)(input)

    def squeeze_excitation(self, inputs, channels, name):
        """Squeeze-and-excitation block with reduction ratio self.SE_ratio."""
        assert channels % self.SE_ratio == 0
        pooled = tf.keras.layers.GlobalAveragePooling2D(
            data_format='channels_first')(inputs)
        squeezed = tf.keras.layers.Activation(self.DEFAULT_ACTIVATION)(
            tf.keras.layers.Dense(channels // self.SE_ratio,
                                  kernel_initializer='glorot_normal',
                                  kernel_regularizer=self.l2reg,
                                  name=name + '/se/dense1')(pooled))
        # 2*channels: scale and shift halves, split in ApplySqueezeExcitation.
        excited = tf.keras.layers.Dense(2 * channels,
                                        kernel_initializer='glorot_normal',
                                        kernel_regularizer=self.l2reg,
                                        name=name + '/se/dense2')(squeezed)
        return ApplySqueezeExcitation()([inputs, excited])

    def conv_block(self, inputs, filter_size, output_channels, name,
                   bn_scale=False):
        """Conv2D -> batch norm -> default activation."""
        conv = tf.keras.layers.Conv2D(output_channels,
                                      filter_size,
                                      use_bias=False,
                                      padding='same',
                                      kernel_initializer='glorot_normal',
                                      kernel_regularizer=self.l2reg,
                                      data_format='channels_first',
                                      name=name + '/conv2d')(inputs)
        return tf.keras.layers.Activation(self.DEFAULT_ACTIVATION)(
            self.batch_norm(conv, name=name + '/bn', scale=bn_scale))

    def residual_block(self, inputs, channels, name):
        """Two 3x3 conv/BN layers with squeeze-excitation and a skip add."""
        conv1 = tf.keras.layers.Conv2D(channels,
                                       3,
                                       use_bias=False,
                                       padding='same',
                                       kernel_initializer='glorot_normal',
                                       kernel_regularizer=self.l2reg,
                                       data_format='channels_first',
                                       name=name + '/1/conv2d')(inputs)
        out1 = tf.keras.layers.Activation(self.DEFAULT_ACTIVATION)(
            self.batch_norm(conv1, name + '/1/bn', scale=False))
        conv2 = tf.keras.layers.Conv2D(channels,
                                       3,
                                       use_bias=False,
                                       padding='same',
                                       kernel_initializer='glorot_normal',
                                       kernel_regularizer=self.l2reg,
                                       data_format='channels_first',
                                       name=name + '/2/conv2d')(out1)
        out2 = self.squeeze_excitation(self.batch_norm(conv2,
                                                       name + '/2/bn',
                                                       scale=True),
                                       channels,
                                       name=name + '/se')
        return tf.keras.layers.Activation(self.DEFAULT_ACTIVATION)(
            tf.keras.layers.add([inputs, out2]))

    @staticmethod
    def split_heads(inputs, batch_size: int, num_heads: int, depth: int):
        """Reshape (batch, 64, num_heads*depth) for multi-head attention;
        identity when there are fewer than 2 heads."""
        if num_heads < 2:
            return inputs
        reshaped = tf.reshape(inputs, (batch_size, 64, num_heads, depth))
        # (batch_size, num_heads, 64, depth)
        return tf.transpose(reshaped, perm=[0, 2, 1, 3])

    def scaled_dot_product_attention(self, q, k, v, name: str = None,
                                     inputs=None):
        """Scaled dot-product attention, optionally adding smolgen-generated
        logits. Returns (output, pre-softmax logits)."""
        # 0 h 64 d, 0 h d 64
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(tf.shape(k)[-1], self.model_dtype)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        heads = scaled_attention_logits.shape[1]

        if self.use_smolgen:
            smolgen_weights = self.smolgen_weights(
                inputs,
                heads,
                self.smolgen_hidden_channels,
                self.smolgen_hidden_sz,
                self.smolgen_gen_sz,
                name=name + '/smolgen',
                activation=self.smolgen_activation)
            scaled_attention_logits = scaled_attention_logits + smolgen_weights

        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)
        return output, scaled_attention_logits

    # multi-head attention in encoder blocks
    def mha(self, inputs, emb_size: int, d_model: int, num_heads: int,
            initializer, name: str):
        """Multi-head self-attention; returns (output, attention logits)."""
        assert d_model % num_heads == 0
        depth = d_model // num_heads

        # query, key, and value vectors for self-attention
        # inputs b, 64, sz
        q = tf.keras.layers.Dense(d_model,
                                  name=name + '/wq',
                                  kernel_initializer='glorot_normal')(inputs)
        k = tf.keras.layers.Dense(d_model,
                                  name=name + '/wk',
                                  kernel_initializer='glorot_normal')(inputs)
        v = tf.keras.layers.Dense(d_model,
                                  name=name + '/wv',
                                  kernel_initializer=initializer)(inputs)

        # split q, k and v into smaller vectors of size 'depth' -- one for each head in multi-head attention
        batch_size = tf.shape(q)[0]
        q = self.split_heads(q, batch_size, num_heads, depth)
        k = self.split_heads(k, batch_size, num_heads, depth)
        v = self.split_heads(v, batch_size, num_heads, depth)

        scaled_attention, attention_weights = self.scaled_dot_product_attention(
            q, k, v, name=name, inputs=inputs)
        if num_heads > 1:
            scaled_attention = tf.transpose(scaled_attention,
                                            perm=[0, 2, 1, 3])
            scaled_attention = tf.reshape(
                scaled_attention,
                (batch_size, -1, d_model))  # concatenate heads

        # final dense layer
        output = tf.keras.layers.Dense(
            emb_size, name=name + "/dense",
            kernel_initializer=initializer)(scaled_attention)
        return output, attention_weights

    # 2-layer dense feed-forward network in encoder blocks
    def ffn(self, inputs, emb_size: int, dff: int, initializer, name: str):
        """Position-wise feed-forward network (dense-activation-dense)."""
        if self.encoder_layers > 0:
            activation = square_relu if self.square_relu_ffn else tf.keras.activations.get(
                self.DEFAULT_ACTIVATION)
        else:
            activation = "selu"
        dense1 = tf.keras.layers.Dense(dff,
                                       name=name + "/dense1",
                                       kernel_initializer=initializer,
                                       activation=activation)(inputs)
        out = tf.keras.layers.Dense(emb_size,
                                    name=name + "/dense2",
                                    kernel_initializer=initializer)(dense1)
        return out

    def encoder_layer(self, inputs, emb_size: int, d_model: int,
                      num_heads: int, dff: int, name: str):
        """One transformer encoder block with DeepNorm-style residual scaling
        when encoder layers are configured."""
        initializer = None
        if self.encoder_layers > 0:
            # DeepNorm
            alpha = tf.cast(tf.math.pow(2. * self.encoder_layers, 0.25),
                            self.model_dtype)
            beta = tf.cast(tf.math.pow(8. * self.encoder_layers, -0.25),
                           self.model_dtype)
            xavier_norm = tf.keras.initializers.VarianceScaling(
                scale=beta, mode='fan_avg', distribution='truncated_normal')
            initializer = xavier_norm
        else:
            alpha = 1
            initializer = "glorot_normal"
        # multihead attention
        attn_output, attn_wts = self.mha(inputs,
                                         emb_size,
                                         d_model,
                                         num_heads,
                                         initializer,
                                         name=name + "/mha")
        # dropout for weight regularization
        attn_output = tf.keras.layers.Dropout(self.dropout_rate,
                                              name=name +
                                              "/dropout1")(attn_output)
        # skip connection + layernorm
        out1 = tf.keras.layers.LayerNormalization(
            epsilon=1e-6, name=name + "/ln1")(inputs * alpha + attn_output)
        # feed-forward network
        ffn_output = self.ffn(out1, emb_size, dff, initializer,
                              name=name + "/ffn")
        ffn_output = tf.keras.layers.Dropout(self.dropout_rate,
                                             name=name +
                                             "/dropout2")(ffn_output)
        out2 = tf.keras.layers.LayerNormalization(
            epsilon=1e-6, name=name + "/ln2")(out1 * alpha + ffn_output)
        return out2, attn_wts

    def smolgen_weights(self, inputs, heads: int, hidden_channels: int,
                        hidden_sz: int, gen_sz: int, name: str,
                        activation='swish'):
        """Generate per-head 64x64 attention-logit offsets (smolgen)."""
        compressed = tf.keras.layers.Dense(hidden_channels,
                                           name=name + '/compress',
                                           use_bias=False)(inputs)
        compressed = tf.reshape(compressed, [-1, 64 * hidden_channels])
        hidden = tf.keras.layers.Dense(hidden_sz,
                                       name=name + '/hidden1_dense',
                                       activation=activation)(compressed)
        hidden = tf.keras.layers.LayerNormalization(name=name +
                                                    '/hidden1_ln')(hidden)
        gen_from = tf.keras.layers.Dense(heads * gen_sz,
                                         name=name + '/gen_from',
                                         activation=activation)(hidden)
        gen_from = tf.keras.layers.LayerNormalization(name=name +
                                                      '/gen_from_ln')(gen_from)
        gen_from = tf.reshape(gen_from, [-1, heads, gen_sz])
        # Shared weight-generator dense created in create_encoder_body.
        out = self.smol_weight_gen_dense(gen_from)
        return tf.reshape(out, [-1, heads, 64, 64])

    def create_residual_body(self, inputs):
        """Input conv block followed by RESIDUAL_BLOCKS residual blocks."""
        flow = self.conv_block(inputs,
                               filter_size=3,
                               output_channels=self.RESIDUAL_FILTERS,
                               name='input',
                               bn_scale=True)
        for i in range(self.RESIDUAL_BLOCKS):
            flow = self.residual_block(flow,
                                       self.RESIDUAL_FILTERS,
                                       name='residual_{}'.format(i + 1))
        return flow

    def create_encoder_body(self, inputs, embedding_size):
        """Transformer encoder body: per-square embedding, optional
        positional encoding and input gating, then encoder layers.
        Returns (flow, attention weights per layer)."""
        # Policy head
        assert self.POLICY_HEAD == pb.NetworkFormat.POLICY_ATTENTION
        # do some input processing
        if self.use_smolgen:
            self.smol_weight_gen_dense = tf.keras.layers.Dense(
                64 * 64, name='smol_weight_gen', use_bias=False)
        flow = tf.transpose(inputs, perm=[0, 2, 3, 1])
        flow = tf.reshape(flow, [-1, 64, tf.shape(inputs)[1]])
        # add positional encoding for each square to the input
        if self.arc_encoding:
            self.POS_ENC = apm.make_pos_enc()
            positional_encoding = tf.broadcast_to(
                tf.convert_to_tensor(self.POS_ENC, dtype=flow.dtype),
                [tf.shape(flow)[0], 64,
                 tf.shape(self.POS_ENC)[2]])
            flow = tf.concat([flow, positional_encoding], axis=2)
        # square embedding
        flow = tf.keras.layers.Dense(embedding_size,
                                     kernel_initializer='glorot_normal',
                                     kernel_regularizer=self.l2reg,
                                     activation=self.DEFAULT_ACTIVATION,
                                     name='embedding')(flow)

        # !!! input gate
        flow = ma_gating(flow, name='embedding')

        attn_wts = []
        for i in range(self.encoder_layers):
            flow, attn_wts_l = self.encoder_layer(
                flow,
                embedding_size,
                self.encoder_d_model,
                self.encoder_heads,
                self.encoder_dff,
                name='encoder_{}'.format(i + 1))
            attn_wts.append(attn_wts_l)
        return flow, attn_wts

    def apply_promotion_logits(self, queries, keys, attn_wts):
        """Build the attention policy output (Bx1856) from query/key logits,
        including underpromotion logits derived from the promotion rank."""
        # PAWN PROMOTION: create promotion logits using scalar offsets generated from the promotion-rank keys
        dk = tf.math.sqrt(tf.cast(tf.shape(keys)[-1],
                                  self.model_dtype))  # constant for scaling
        promotion_keys = keys[:, -8:, :]
        # queen, rook, bishop, knight order
        promotion_offsets = tf.keras.layers.Dense(
            4,
            kernel_initializer='glorot_normal',
            name='policy/attention/ppo',
            use_bias=False)(promotion_keys)
        promotion_offsets = tf.transpose(promotion_offsets,
                                         perm=[0, 2, 1]) * dk  # Bx4x8
        # knight offset is added to the other three
        promotion_offsets = promotion_offsets[:, :3, :] + promotion_offsets[:, 3:4, :]

        # POLICY SELF-ATTENTION: self-attention weights are interpreted as from->to policy
        matmul_qk = tf.matmul(
            queries, keys,
            transpose_b=True)  # Bx64x64 (from 64 queries, 64 keys)

        # q, r, and b promotions are offset from the default promotion logit (knight)
        n_promo_logits = matmul_qk[:, -16:-8, -8:]  # default traversals from penultimate rank to promotion rank
        q_promo_logits = tf.expand_dims(
            n_promo_logits + promotion_offsets[:, 0:1, :], axis=3)  # Bx8x8x1
        r_promo_logits = tf.expand_dims(
            n_promo_logits + promotion_offsets[:, 1:2, :], axis=3)
        b_promo_logits = tf.expand_dims(
            n_promo_logits + promotion_offsets[:, 2:3, :], axis=3)
        promotion_logits = tf.concat(
            [q_promo_logits, r_promo_logits, b_promo_logits],
            axis=3)  # Bx8x8x3
        promotion_logits = tf.reshape(
            promotion_logits,
            [-1, 8, 24])  # logits now alternate a7a8q,a7a8r,a7a8b,...,

        # scale the logits by dividing them by sqrt(d_model) to stabilize gradients
        promotion_logits = promotion_logits / dk  # Bx8x24 (8 from-squares, 3x8 promotions)
        policy_attn_logits = matmul_qk / dk  # Bx64x64 (64 from-squares, 64 to-squares)

        attn_wts.append(promotion_logits)
        attn_wts.append(policy_attn_logits)

        # APPLY POLICY MAP: output becomes Bx1856
        h_fc1 = ApplyAttentionPolicyMap()(policy_attn_logits,
                                          promotion_logits)
        return h_fc1

    def construct_net(self, inputs, name=''):
        """Assemble the full network: body plus policy (and further) heads."""
        if self.encoder_layers > 0:
            flow, attn_wts = self.create_encoder_body(inputs,
                                                      self.embedding_size)
        else:
            flow = self.create_residual_body(inputs)

        # Policy head
        if self.POLICY_HEAD == pb.NetworkFormat.POLICY_CONVOLUTION:
            conv_pol = self.conv_block(flow,
                                       filter_size=3,
                                       output_channels=self.RESIDUAL_FILTERS,
                                       name='policy1')
            conv_pol2 = tf.keras.layers.Conv2D(
                80,
                3,
                use_bias=True,
                padding='same',
                kernel_initializer='glorot_normal',
                kernel_regularizer=self.l2reg,
                bias_regularizer=self.l2reg,
                data_format='channels_first',
                name='policy')(conv_pol)
            h_fc1 = ApplyPolicyMap()(conv_pol2)
        elif self.POLICY_HEAD == pb.NetworkFormat.POLICY_CLASSICAL:
            conv_pol = self.conv_block(flow,
                                       filter_size=1,
                                       output_channels=self.policy_channels,
                                       name='policy')
h_conv_pol_flat = tf.keras.layers.Flatten()(conv_pol) h_fc1 = tf.keras.layers.Dense(1858, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg, name='policy/dense')(h_conv_pol_flat) elif self.POLICY_HEAD == pb.NetworkFormat.POLICY_ATTENTION: if self.encoder_layers == 0: attn_wts = [] if self.RESIDUAL_BLOCKS > 0: # transpose and reshape tokens = tf.transpose(flow, perm=[0, 2, 3, 1]) tokens = tf.reshape(tokens, [-1, 64, self.RESIDUAL_FILTERS]) embed_activation = 'selu' else: tokens = flow embed_activation = self.DEFAULT_ACTIVATION # SQUARE EMBEDDING: found to increase attention head performance tokens = tf.keras.layers.Dense(self.pol_embedding_size, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation=embed_activation, name='policy/embedding')(tokens) if self.RESIDUAL_BLOCKS > 0: # ENCODER LAYERS: intermediate layers of self-attention with residual connections for i in range(self.pol_encoder_layers): tokens, attn_wts_l = self.encoder_layer( tokens, self.pol_embedding_size, self.pol_encoder_d_model, self.pol_encoder_heads, self.pol_encoder_dff, name='policy/enc_layer_{}'.format(i + 1)) attn_wts.append(attn_wts_l) # create queries and keys for policy self-attention queries = tf.keras.layers.Dense(self.policy_d_model, kernel_initializer='glorot_normal', name='policy/attention/wq')(tokens) keys = tf.keras.layers.Dense(self.policy_d_model, kernel_initializer='glorot_normal', name='policy/attention/wk')(tokens) h_fc1 = self.apply_promotion_logits(queries, keys, attn_wts) else: raise ValueError("Unknown policy head type {}".format( self.POLICY_HEAD)) # Value head if self.encoder_layers > 0: conv_val = tf.keras.layers.Dense( self.val_embedding_size, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation=self.DEFAULT_ACTIVATION, name='value/embedding')(flow) else: conv_val = self.conv_block(flow, filter_size=1, output_channels=32, name='value') h_conv_val_flat = 
tf.keras.layers.Flatten()(conv_val) h_fc2 = tf.keras.layers.Dense(128, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation=self.DEFAULT_ACTIVATION, name='value/dense1')(h_conv_val_flat) if self.wdl: h_fc3 = tf.keras.layers.Dense(3, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg, name='value/dense2')(h_fc2) else: h_fc3 = tf.keras.layers.Dense(1, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation='tanh', name='value/dense2')(h_fc2) # Moves left head if self.moves_left: if self.encoder_layers > 0: conv_mov = tf.keras.layers.Dense( self.mov_embedding_size, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation=self.DEFAULT_ACTIVATION, name='moves_left/embedding')(flow) else: conv_mov = self.conv_block(flow, filter_size=1, output_channels=8, name='moves_left') h_conv_mov_flat = tf.keras.layers.Flatten()(conv_mov) h_fc4 = tf.keras.layers.Dense( 128, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation=self.DEFAULT_ACTIVATION, name='moves_left/dense1')(h_conv_mov_flat) h_fc5 = tf.keras.layers.Dense(1, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation='relu', name='moves_left/dense2')(h_fc4) else: h_fc5 = None # attention weights added as optional output for analysis -- ignored by backend if self.POLICY_HEAD == pb.NetworkFormat.POLICY_ATTENTION: if self.moves_left: outputs = [h_fc1, h_fc3, h_fc5, attn_wts] else: outputs = [h_fc1, h_fc3, attn_wts] elif self.moves_left: outputs = [h_fc1, h_fc3, h_fc5] else: outputs = [h_fc1, h_fc3] return outputs ================================================ FILE: tf/train.py ================================================ #!/usr/bin/env python3 # # This file is part of Leela Zero. 
# Copyright (C) 2017 Gian-Carlo Pascutto
#
# Leela Zero is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Leela Zero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Leela Zero. If not, see <http://www.gnu.org/licenses/>.

import argparse
import os
import yaml
import sys
import glob
import gzip
import random
import multiprocessing as mp

from chunkparser import ChunkParser

# Only 1 in every SKIP positions of a chunk is sampled for training.
SKIP = 32


def get_chunks(data_prefix):
    """Return all gzipped chunk files whose path starts with data_prefix."""
    return glob.glob(data_prefix + "*.gz")


def get_all_chunks(path):
    """Collect chunk files for a glob pattern or a list of patterns
    (recursing over list entries)."""
    if isinstance(path, list):
        print("getting chunks for", path)
        chunks = []
        for i in path:
            chunks += get_all_chunks(i)
        return chunks
    chunks = []
    for d in glob.glob(path):
        chunks += get_chunks(d)
    print("got", len(chunks), "chunks for", path)
    return chunks


def get_latest_chunks(path, num_chunks, allow_less, sort_key_fn):
    """Return the newest num_chunks chunks under path, shuffled.

    Sorting uses sort_key_fn in descending order, so the retained slice is
    the "latest" per that key. If fewer than num_chunks exist, either all
    of them are returned (allow_less) or the process exits with an error.
    """
    chunks = get_all_chunks(path)
    if len(chunks) < num_chunks:
        if allow_less:
            # NOTE(review): if chunks is empty here, chunks[-1] below
            # raises IndexError — confirm callers never pass an empty
            # data directory with allow_less set.
            print("sorting {} chunks...".format(len(chunks)),
                  end='',
                  flush=True)
            chunks.sort(key=sort_key_fn, reverse=True)
            print("[done]")
            print("{} - {}".format(os.path.basename(chunks[-1]),
                                   os.path.basename(chunks[0])))
            random.shuffle(chunks)
            return chunks
        else:
            print("Not enough chunks {}".format(len(chunks)))
            sys.exit(1)
    print("sorting {} chunks...".format(len(chunks)), end='', flush=True)
    chunks.sort(key=sort_key_fn, reverse=True)
    print("[done]")
    chunks = chunks[:num_chunks]
    print("{} - {}".format(os.path.basename(chunks[-1]),
                           os.path.basename(chunks[0])))
    random.shuffle(chunks)
    return chunks


def identity_function(name):
    """Sort key for sort_type 'name': the file name itself."""
    return name


def game_number_for_name(name):
    """Sort key for sort_type 'number': the numeric part of the file name.

    Strips letters and the characters _-. from both ends of the basename;
    raises ValueError if no integer remains.
    """
    num_str = os.path.basename(name).upper().strip(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ_-.")
    return int(num_str)


def get_input_mode(cfg):
    """Map the config's model.input_type string to the proto input-format
    enum; raises ValueError for unknown modes."""
    import proto.net_pb2 as pb
    input_mode = cfg['model'].get('input_type', 'classic')
    if input_mode == "classic":
        return pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE
    elif input_mode == "frc_castling":
        return pb.NetworkFormat.INPUT_112_WITH_CASTLING_PLANE
    elif input_mode == "canonical":
        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION
    elif input_mode == "canonical_100":
        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES
    elif input_mode == "canonical_armageddon":
        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON
    elif input_mode == "canonical_v2":
        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2
    elif input_mode == "canonical_v2_armageddon":
        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON
    else:
        raise ValueError("Unknown input mode format: {}".format(input_mode))


def main(cmd):
    """Run the training pipeline described by the YAML config in cmd.cfg.

    Selects train/test(/validation) chunks, builds ChunkParser-backed
    tf.data pipelines, runs TFProcess.process_loop, and optionally saves
    weights to cmd.output.
    """
    cfg = yaml.safe_load(cmd.cfg.read())
    print(yaml.dump(cfg, default_flow_style=False))

    num_chunks = cfg['dataset']['num_chunks']
    allow_less = cfg['dataset'].get('allow_less_chunks', False)
    train_ratio = cfg['dataset']['train_ratio']
    num_train = int(num_chunks * train_ratio)
    num_test = num_chunks - num_train
    sort_type = cfg['dataset'].get('sort_type', 'mtime')
    if sort_type == 'mtime':
        sort_key_fn = os.path.getmtime
    elif sort_type == 'number':
        sort_key_fn = game_number_for_name
    elif sort_type == 'name':
        sort_key_fn = identity_function
    else:
        raise ValueError('Unknown dataset sort_type: {}'.format(sort_type))
    # Either separate train/test inputs, or one input split by train_ratio.
    if 'input_test' in cfg['dataset']:
        train_chunks = get_latest_chunks(cfg['dataset']['input_train'],
                                         num_train, allow_less, sort_key_fn)
        test_chunks = get_latest_chunks(cfg['dataset']['input_test'],
                                        num_test, allow_less, sort_key_fn)
    else:
        chunks = get_latest_chunks(cfg['dataset']['input'], num_chunks,
                                   allow_less, sort_key_fn)
        if allow_less:
            # fewer chunks than requested: re-derive the split sizes
            num_train = int(len(chunks) * train_ratio)
            num_test = len(chunks) - num_train
        train_chunks = chunks[:num_train]
        test_chunks = chunks[num_train:]

    shuffle_size = cfg['training']['shuffle_size']
    total_batch_size = cfg['training']['batch_size']
    batch_splits = cfg['training'].get('num_batch_splits', 1)
    train_workers = cfg['dataset'].get('train_workers', None)
    test_workers = cfg['dataset'].get('test_workers', None)
    if total_batch_size % batch_splits != 0:
        raise ValueError('num_batch_splits must divide batch_size evenly')
    split_batch_size = total_batch_size // batch_splits
    # difficulty-focus sampling parameters (train parser only)
    diff_focus_min = cfg['training'].get('diff_focus_min', 1)
    diff_focus_slope = cfg['training'].get('diff_focus_slope', 0)
    diff_focus_q_weight = cfg['training'].get('diff_focus_q_weight', 6.0)
    diff_focus_pol_scale = cfg['training'].get('diff_focus_pol_scale', 3.5)

    root_dir = os.path.join(cfg['training']['path'], cfg['name'])
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)

    train_parser = ChunkParser(train_chunks,
                               get_input_mode(cfg),
                               shuffle_size=shuffle_size,
                               sample=SKIP,
                               batch_size=split_batch_size,
                               diff_focus_min=diff_focus_min,
                               diff_focus_slope=diff_focus_slope,
                               diff_focus_q_weight=diff_focus_q_weight,
                               diff_focus_pol_scale=diff_focus_pol_scale,
                               workers=train_workers)
    test_shuffle_size = int(shuffle_size * (1.0 - train_ratio))
    # no diff focus for test_parser
    test_parser = ChunkParser(test_chunks,
                              get_input_mode(cfg),
                              shuffle_size=test_shuffle_size,
                              sample=SKIP,
                              batch_size=split_batch_size,
                              workers=test_workers)
    if 'input_validation' in cfg['dataset']:
        valid_chunks = get_all_chunks(cfg['dataset']['input_validation'])
        # sample=1: validation reads every position, no subsampling
        validation_parser = ChunkParser(valid_chunks,
                                        get_input_mode(cfg),
                                        sample=1,
                                        batch_size=split_batch_size,
                                        workers=0)

    # deferred imports: keep TF out of process startup (parser workers fork)
    import tensorflow as tf
    from chunkparsefunc import parse_function
    from tfprocess import TFProcess
    tfprocess = TFProcess(cfg)
    train_dataset = tf.data.Dataset.from_generator(
        train_parser.parse,
        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string))
    train_dataset = train_dataset.map(parse_function)
    test_dataset = tf.data.Dataset.from_generator(
        test_parser.parse,
        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string))
    test_dataset = test_dataset.map(parse_function)
    validation_dataset = None
    if 'input_validation' in cfg['dataset']:
        validation_dataset = tf.data.Dataset.from_generator(
            validation_parser.sequential,
            output_types=(tf.string, tf.string, tf.string, tf.string,
                          tf.string))
        validation_dataset = validation_dataset.map(parse_function)
    if tfprocess.strategy is None:  #Mirrored strategy appends prefetch itself with a value depending on number of replicas
        train_dataset = train_dataset.prefetch(4)
        test_dataset = test_dataset.prefetch(4)
        if validation_dataset is not None:
            validation_dataset = validation_dataset.prefetch(4)
    else:
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
        train_dataset = train_dataset.with_options(options)
        test_dataset = test_dataset.with_options(options)
        if validation_dataset is not None:
            validation_dataset = validation_dataset.with_options(options)
    tfprocess.init(train_dataset, test_dataset, validation_dataset)

    tfprocess.restore()

    # If number of test positions is not given
    # sweeps through all test chunks statistically
    # Assumes average of 10 samples per test game.
    # For simplicity, testing can use the split batch size instead of total batch size.
    # This does not affect results, because test results are simple averages that are independent of batch size.
    num_evals = cfg['training'].get('num_test_positions',
                                    len(test_chunks) * 10)
    num_evals = max(1, num_evals // split_batch_size)
    print("Using {} evaluation batches".format(num_evals))
    tfprocess.total_batch_size = total_batch_size
    tfprocess.process_loop(total_batch_size,
                           num_evals,
                           batch_splits=batch_splits)

    if cmd.output is not None:
        if cfg['training'].get('swa_output', False):
            tfprocess.save_swa_weights(cmd.output)
        else:
            tfprocess.save_leelaz_weights(cmd.output)

    train_parser.shutdown()
    test_parser.shutdown()


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description=\
        'Tensorflow pipeline for training Leela Chess.')
    argparser.add_argument('--cfg',
                           type=argparse.FileType('r'),
                           help='yaml configuration with training parameters')
    argparser.add_argument('--output',
                           type=str,
                           help='file to store weights in')

    #mp.set_start_method('spawn')
    main(argparser.parse_args())
    # NOTE(review): freeze_support() is conventionally the first statement
    # under the __main__ guard; here it runs after main() returns — confirm
    # this is intentional (it is a no-op outside frozen Windows builds).
    mp.freeze_support()


================================================
FILE: tf/update_steps.py
================================================
#!/usr/bin/env python3
import argparse
import os
import yaml
import sys

import tensorflow as tf

from tfprocess import TFProcess

START_FROM = 0


def main(cmd):
    """Rewrite the latest checkpoint so its global_step equals cmd.start."""
    cfg = yaml.safe_load(cmd.cfg.read())
    print(yaml.dump(cfg, default_flow_style=False))

    root_dir = os.path.join(cfg['training']['path'], cfg['name'])
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)

    tfprocess = TFProcess(cfg)
    tfprocess.init_net()
    tfprocess.restore()

    # NOTE(review): this creates a local START_FROM shadowing the
    # module-level constant (no `global` statement) — harmless here, but
    # the module-level START_FROM is never actually used.
    START_FROM = cmd.start

    tfprocess.global_step.assign(START_FROM)
    tfprocess.manager.save(checkpoint_number=START_FROM)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description=\
        'Convert current checkpoint to new step count.')
    argparser.add_argument('--cfg',
                           type=argparse.FileType('r'),
                           help='yaml configuration with training parameters')
    argparser.add_argument('--start',
                           type=int,
                           default=0,
                           help='Offset to set global_step to.')

    main(argparser.parse_args())