Repository: inference-labs-inc/dsperse Branch: main Commit: a71f618e5cd5 Files: 72 Total size: 939.0 KB Directory structure: gitextract_ogh8afab/ ├── .cargo/ │ ├── audit.toml │ └── config.toml ├── .github/ │ └── workflows/ │ ├── integration_tests.yml │ └── publish.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── crates/ │ └── dsperse/ │ ├── Cargo.toml │ ├── benches/ │ │ └── serialization.rs │ ├── build.rs │ ├── proto/ │ │ └── onnx.proto │ ├── src/ │ │ ├── backend/ │ │ │ ├── jstprove.rs │ │ │ ├── mod.rs │ │ │ ├── onnx.rs │ │ │ └── traits.rs │ │ ├── cli/ │ │ │ └── mod.rs │ │ ├── converter.rs │ │ ├── error.rs │ │ ├── lib.rs │ │ ├── main.rs │ │ ├── pipeline/ │ │ │ ├── channel_split.rs │ │ │ ├── combined.rs │ │ │ ├── compiler.rs │ │ │ ├── dim_split.rs │ │ │ ├── incremental.rs │ │ │ ├── mod.rs │ │ │ ├── packager.rs │ │ │ ├── prover.rs │ │ │ ├── publisher.rs │ │ │ ├── runner.rs │ │ │ ├── slice_cache.rs │ │ │ ├── stage.rs │ │ │ ├── strategy.rs │ │ │ ├── tensor_store.rs │ │ │ ├── tile_executor.rs │ │ │ ├── tiled.rs │ │ │ └── verifier.rs │ │ ├── python.rs │ │ ├── schema/ │ │ │ ├── execution.rs │ │ │ ├── metadata.rs │ │ │ ├── mod.rs │ │ │ └── tiling.rs │ │ ├── slicer/ │ │ │ ├── analyzer.rs │ │ │ ├── autotiler.rs │ │ │ ├── combiner.rs │ │ │ ├── layernorm_fuse.rs │ │ │ ├── materializer.rs │ │ │ ├── mod.rs │ │ │ ├── onnx_fold.rs │ │ │ ├── onnx_proto.rs │ │ │ ├── onnx_shapes.rs │ │ │ ├── onnx_slicer.rs │ │ │ ├── self_div_rewrite.rs │ │ │ └── trace.rs │ │ ├── utils/ │ │ │ ├── io.rs │ │ │ ├── limits.rs │ │ │ ├── metadata.rs │ │ │ ├── mod.rs │ │ │ └── paths.rs │ │ └── version.rs │ └── tests/ │ ├── integration_slice.rs │ ├── schema_roundtrip.rs │ └── sn2_contract.rs ├── deny.toml ├── docs/ │ ├── JSTPROVE_BACKEND.md │ ├── overview.md │ └── uv_packaging.md ├── pyproject.toml ├── python/ │ └── dsperse/ │ ├── __init__.py │ └── cli.py └── rust-toolchain.toml ================================================ FILE CONTENTS ================================================ 
================================================ FILE: .cargo/audit.toml ================================================ [advisories] ignore = [ "RUSTSEC-2026-0009", # time crate DoS via RFC 2822 parsing — transitive dep, not user-facing ] ================================================ FILE: .cargo/config.toml ================================================ [net] git-fetch-with-cli = true ================================================ FILE: .github/workflows/integration_tests.yml ================================================ name: Integration Tests on: push: branches: - main pull_request: branches: - main concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true permissions: contents: read jobs: fmt: name: Rustfmt runs-on: ubuntu-latest timeout-minutes: 10 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install Rust toolchain run: rustup show - run: cargo fmt --check test: name: Rust Tests runs-on: ubuntu-latest timeout-minutes: 45 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true - name: Install Rust toolchain run: rustup show - name: Install protoc run: sudo apt-get update && sudo apt-get install -y protobuf-compiler - uses: Swatinem/rust-cache@e18b497796c12c097a38f9edb9d0641fb99eee32 # v2.9.1 - name: Test run: cargo test --locked --manifest-path crates/dsperse/Cargo.toml - name: Test (with python feature) run: cargo test --locked --manifest-path crates/dsperse/Cargo.toml --features python - name: Clippy run: cargo clippy --locked --manifest-path crates/dsperse/Cargo.toml --all-targets --features python -- -D warnings audit: name: Security audit runs-on: ubuntu-latest timeout-minutes: 10 permissions: contents: read checks: write steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - run: rm -f rust-toolchain.toml && rustup install stable && rustup default stable - uses: 
rustsec/audit-check@69366f33c96575abad1ee0dba8212993eecbe998 # v2.0.0 with: token: ${{ secrets.GITHUB_TOKEN }} deny: name: Cargo deny runs-on: ubuntu-latest timeout-minutes: 10 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - uses: EmbarkStudios/cargo-deny-action@3fd3802e88374d3fe9159b834c7714ec57d6c979 # v2.0.15 with: command: check bans sources ================================================ FILE: .github/workflows/publish.yml ================================================ name: Build and Publish to PyPI on: push: tags: - "v*" pull_request: workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true env: UV_VERSION: "0.10.8" MATURIN_VERSION: "1.12.6" jobs: build-linux: if: >- github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'test-build') runs-on: ubuntu-latest timeout-minutes: 60 container: image: quay.io/pypa/manylinux_2_28_x86_64 permissions: contents: read steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Set up Python run: echo "/opt/python/cp312-cp312/bin" >> $GITHUB_PATH - name: Install uv uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7 with: version: ${{ env.UV_VERSION }} - name: Install system dependencies run: | dnf install -y protobuf-compiler protobuf-devel pkgconf-pkg-config perl-IPC-Cmd perl-Time-Piece clang-devel - name: Install Rust run: | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly-2025-03-27 echo "$HOME/.cargo/bin" >> $GITHUB_PATH - name: Extract version id: get_version shell: bash run: | if [[ "$GITHUB_REF" == refs/tags/v* ]]; then VERSION=${GITHUB_REF#refs/tags/v} else VERSION=$(grep -m1 '^version' pyproject.toml | sed 's/.*"\(.*\)".*/\1/') fi echo "version=$VERSION" >> $GITHUB_OUTPUT - name: Update versions run: | sed -i '0,/^version = ".*"/{s/^version = ".*"/version = "${{ 
steps.get_version.outputs.version }}"/}' pyproject.toml sed -i '0,/^version = ".*"/{s/^version = ".*"/version = "${{ steps.get_version.outputs.version }}"/}' crates/dsperse/Cargo.toml - name: Cache Rust dependencies uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 with: path: | ~/.cargo/registry ~/.cargo/git target key: manylinux-2-28-cargo-${{ hashFiles('**/Cargo.lock') }} restore-keys: | manylinux-2-28-cargo- - name: Build wheel run: uvx maturin==${{ env.MATURIN_VERSION }} build --release --manylinux 2_28 -i /opt/python/cp312-cp312/bin/python3 - name: Test wheel installation run: | uv pip install --system --python python3 target/wheels/*.whl python3 -c "from dsperse import slice_model; print('PyO3 bindings OK')" - name: Upload wheel artifact uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: wheel-ubuntu-x86_64 path: ./target/wheels/*.whl build-macos: if: >- github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'test-build') runs-on: macos-latest timeout-minutes: 60 permissions: contents: read steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Set up Python uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.12" - name: Install uv uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7 with: version: ${{ env.UV_VERSION }} - name: Install system dependencies run: brew install protobuf llvm - name: Extract version id: get_version shell: bash run: | if [[ "$GITHUB_REF" == refs/tags/v* ]]; then VERSION=${GITHUB_REF#refs/tags/v} else VERSION=$(grep -m1 '^version' pyproject.toml | sed 's/.*"\(.*\)".*/\1/') fi echo "version=$VERSION" >> $GITHUB_OUTPUT - name: Update versions run: | sed -i '' '1,/^version = /{s/^version = ".*"/version = "${{ steps.get_version.outputs.version }}"/;}' pyproject.toml sed -i '' '1,/^version = /{s/^version = ".*"/version = "${{ 
steps.get_version.outputs.version }}"/;}' crates/dsperse/Cargo.toml - name: Install Rust uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable 2026-02-13 with: toolchain: nightly-2025-03-27 - name: Install Rust target run: rustup target add aarch64-apple-darwin - name: Cache Rust dependencies uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 with: path: | ~/.cargo/registry ~/.cargo/git target key: macos-cargo-${{ hashFiles('**/Cargo.lock') }} restore-keys: | macos-cargo- - name: Build wheel run: uvx maturin==${{ env.MATURIN_VERSION }} build --release --target aarch64-apple-darwin env: MACOSX_DEPLOYMENT_TARGET: "11.0" - name: Test wheel installation run: | uv pip install --system --python python3 target/wheels/*.whl python3 -c "from dsperse import slice_model; print('PyO3 bindings OK')" - name: Upload wheel artifact uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: wheel-macos-aarch64 path: ./target/wheels/*.whl publish: if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') needs: [build-linux, build-macos] runs-on: ubuntu-latest timeout-minutes: 15 permissions: contents: write id-token: write steps: - name: Download all wheels uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: wheel-* merge-multiple: true path: ./dist - name: Extract version from tag id: get_version shell: bash run: | VERSION=${GITHUB_REF#refs/tags/v} echo "version=$VERSION" >> $GITHUB_OUTPUT - name: Create GitHub Release with wheels uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2.6.1 with: name: Release ${{ steps.get_version.outputs.version }} files: ./dist/*.whl - name: Install uv uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7 with: version: ${{ env.UV_VERSION }} - name: Publish to PyPI run: uv publish ./dist/* ================================================ FILE: .gitignore 
================================================ # macOS system files .DS_Store .DS_* tests/models/run # macOS metadata ._* # Python cache __pycache__/ *.py[cod] # Environment files .env .venv/ env/ venv/ # IDE/editor folders .vscode/ .idea/ # Log files *.log # Byte-compiled *.pyo # Jupyter Notebook checkpoints .ipynb_checkpoints/ # Python egg artifacts *.egg *.egg-info/ dist/ build/ eggs/ parts/ bin/ var/ sdist/ develop-eggs/ .installed.cfg # ignore the models we test with */models/*/slices */src/models/*/slices/ */models/*/model_metadata.json */src/models/*/model_metadata.json */models/*/analysis/model_metadata.json */src/models/*/analysis/model_metadata.json */models/*/run */src/models/*/run/ */models/*/input.json */src/models/*/input.json */models/*/*.onnx */src/models/*/*.onnx */models/*/*.dsperse */src/models/*/*.dsperse */models/*/*.data */src/models/*/*.data # Local virtual envs python.venv/ .venv/ venv/ # Slice output directories pitch-sliced/ *-sliced/ # Test output tests/models/output/ /target /crates/*/target ================================================ FILE: Cargo.toml ================================================ [workspace] members = ["crates/dsperse"] resolver = "2" [workspace.package] edition = "2024" [workspace.dependencies] serde = { version = "1", features = ["derive"] } rmpv = { version = "1", features = ["with-serde"] } rmp-serde = "1" thiserror = "2" clap = { version = "4", features = ["derive", "env"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } rayon = "1" ndarray = { version = "0.17", features = ["serde"] } tract-onnx = { git = "https://github.com/inference-labs-inc/tract.git", rev = "3cfae7f7" } uuid = { version = "1", features = ["v4"] } sha2 = "0.10" tempfile = "3" prost = "0.13" pyo3 = { version = "0.24" } jstprove_circuits = { git = "https://github.com/inference-labs-inc/JSTprove.git", rev = "87a1859f3487cf0fb9a463dbfd713b1df4827afc" } jstprove_io = { git = 
"https://github.com/inference-labs-inc/JSTprove.git", rev = "87a1859f3487cf0fb9a463dbfd713b1df4827afc", package = "jstprove-io" } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] } tokio = { version = "1", features = ["rt", "macros"] } ================================================ FILE: LICENSE ================================================ Copyright (c) 2025 Inference Labs Inc. Source Access Grant You may access, view, study, and modify the source code of this software. Redistribution Conditions You may redistribute this software in source or modified form provided that: a) You retain this license document and all copyright notices b) Any modified files carry prominent notices stating you changed them c) You do not misrepresent the origin of the software Usage Restriction NO USE RIGHTS ARE GRANTED BY THIS LICENSE. Any operational use including but not limited to: - Execution of the software - Integration with other systems - Deployment in any environment - Commercial or production utilization requires express written permission from the IP Owner. Intellectual Property Reservation All rights not expressly granted herein are reserved by the IP Owner. For usage permissions, contact: legal@inferencelabs.com Disclaimer THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND. THE IP OWNER SHALL NOT BE LIABLE FOR ANY DAMAGES ARISING FROM ACCESS OR DISTRIBUTION. License Propagation Any distribution of this software or derivatives must be under this same license agreement. 
================================================ FILE: README.md ================================================ # DSperse: Community Edition [![GitHub](https://img.shields.io/badge/GitHub-Repository-blue?style=flat-square&logo=github)](https://github.com/inference-labs-inc/dsperse) [![Discord](https://img.shields.io/badge/Discord-Join%20Community-7289DA?style=flat-square&logo=discord)](https://discord.gg/GBxBCWJs) [![Telegram](https://img.shields.io/badge/Telegram-Join%20Channel-0088cc?style=flat-square&logo=telegram)](https://t.me/inference_labs) [![Twitter](https://img.shields.io/badge/Twitter-Follow%20Us-1DA1F2?style=flat-square&logo=twitter)](https://x.com/inference_labs) [![Website](https://img.shields.io/badge/Website-Visit%20Us-ff7139?style=flat-square&logo=firefox-browser)](https://inferencelabs.com) [![Whitepaper](https://img.shields.io/badge/Whitepaper-Read-lightgrey?style=flat-square&logo=read-the-docs)](http://arxiv.org/abs/2508.06972) DSperse is a proving-system-agnostic intelligent slicer for verifiable AI. It decomposes ONNX neural network models into circuit-compatible segments and orchestrates compilation, inference, proving, and verification across pluggable ZK backends. 
## Features - **Model Slicing**: Split neural network models into individual layers or custom segments - **ONNX Support**: Slice and orchestrate ONNX models - **Layered Inference**: Run inference on sliced models, chaining the output of each segment - **Zero-Knowledge Proofs**: Generate and verify proofs for model execution via JSTprove - **Tiling and Channel Splitting**: Automatically decompose large convolutions for circuit-compatible execution - **Proof System Agnostic**: Pluggable backend architecture supporting Expander and Remainder proof systems ## Documentation - [Overview](docs/overview.md): High-level overview of the project, its goals, and features - [JSTprove Backend](docs/JSTPROVE_BACKEND.md): JSTprove integration and usage ## Installation ### From PyPI (includes CLI) ```bash pip install dsperse ``` This installs both the `dsperse` CLI command and the Python library bindings. No additional dependencies required — everything is compiled into a single native extension. ### From source (Rust binary) ```bash cargo install --path crates/dsperse ``` ### As a Rust library ```toml [dependencies] dsperse = { git = "https://github.com/inference-labs-inc/dsperse.git" } ``` ## CLI Usage DSperse provides six subcommands that form a complete pipeline: | Command | Description | |---------|-------------| | `slice` | Split an ONNX model into segments | | `compile` | Compile slices into ZK circuits | | `run` | Execute chained inference across slices (`--weights` to inject consumer ONNX) | | `prove` | Generate ZK proofs for a completed run | | `verify` | Verify ZK proofs | | `full-run` | Execute compile, run, prove, verify in sequence (supports `--weights`) | ### Quickstart ```bash dsperse slice --model-dir models/net dsperse compile --model-dir models/net --parallel 4 dsperse run --model-dir models/net --input-file models/net/input.json dsperse prove --model-dir models/net --run-dir models/net/run/run_* dsperse verify --model-dir models/net --run-dir 
models/net/run/run_* ``` Or run the entire pipeline at once: ```bash dsperse full-run --model-dir models/net --input-file models/net/input.json ``` To inject consumer weights from a fine-tuned ONNX model (same architecture, different weights): ```bash dsperse run --model-dir models/net --input-file models/net/input.json --weights path/to/consumer.onnx dsperse full-run --model-dir models/net --input-file models/net/input.json --weights path/to/consumer.onnx ``` ## Python Library Usage ```python import dsperse metadata_json = dsperse.slice_model("models/net/model.onnx", output_dir="models/net/slices") dsperse.compile_slices("models/net/slices", parallel=4) run_json = dsperse.run_inference("models/net/slices", "models/net/input.json", "models/net/run") proof_json = dsperse.prove_run("models/net/run", "models/net/slices") verify_json = dsperse.verify_run("models/net/run", "models/net/slices") ``` To inject consumer weights at inference time, pass `weights_onnx` (path to a fine-tuned ONNX with the same architecture): ```python run_json = dsperse.run_inference( "models/net/slices", "models/net/input.json", "models/net/run", weights_onnx="path/to/consumer.onnx", ) ``` `slice_model`, `run_inference`, `prove_run`, and `verify_run` return JSON strings parseable with `json.loads()`. `compile_slices` returns `None`. ## Project Structure ```text crates/dsperse/ src/ cli/ CLI argument parsing and command dispatch slicer/ ONNX model analysis, slicing, autotiling, channel splitting pipeline/ Compilation, inference, proving, verification orchestration backend/ JSTprove backend integration schema/ Metadata and execution result types (serde) converter.rs Prepares JSTprove artifacts from ONNX files utils/ I/O helpers and path resolution tests/ Unit and integration tests python/ Thin Python wrapper for PyO3 bindings ``` ## Contributing Contributions are welcome. Please open issues and PRs on GitHub. ## License See the [LICENSE](LICENSE) file for details. 
================================================ FILE: crates/dsperse/Cargo.toml ================================================ [package] name = "dsperse" version = "0.0.0" edition.workspace = true [features] default = [] python = ["dep:pyo3", "pyo3/extension-module"] [dependencies] serde.workspace = true rmpv.workspace = true rmp-serde.workspace = true thiserror.workspace = true clap.workspace = true tracing.workspace = true tracing-subscriber.workspace = true rayon.workspace = true ndarray.workspace = true tract-onnx.workspace = true uuid.workspace = true sha2.workspace = true tempfile.workspace = true prost.workspace = true pyo3 = { workspace = true, optional = true } serde_json = "1" zip = { version = "2", default-features = false, features = ["deflate"] } walkdir = "2" jstprove_circuits.workspace = true jstprove_io.workspace = true reqwest.workspace = true tokio.workspace = true [target.'cfg(unix)'.dependencies] libc = "0.2" [build-dependencies] prost-build = "0.13" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } [[bench]] name = "serialization" harness = false [lib] name = "dsperse" crate-type = ["cdylib", "lib"]
================================================ FILE: crates/dsperse/benches/serialization.rs ================================================
use std::collections::HashMap;

use criterion::{Criterion, black_box, criterion_group, criterion_main};
use dsperse::schema::execution::{
    ExecutionChain, ExecutionInfo, ExecutionMethod, ExecutionNode, ExecutionResultEntry,
    RunMetadata, SliceResult, TileResult,
};
use dsperse::schema::metadata::{
    BackendKind, Compilation, Dependencies, ModelMetadata, RunSliceMetadata, SliceMetadata,
    SliceShapeWrapper, TensorShape,
};
use serde::{Deserialize, Serialize};

/// Build a synthetic `SliceMetadata` for slice number `index`.
///
/// Paths, shapes and dependency names are placeholders; only the size and
/// structure of the value matter for the serialization benchmarks.
fn make_slice_metadata(index: usize) -> SliceMetadata {
    SliceMetadata {
        index,
        filename: format!("slice_{index}.onnx"),
        path: format!("/tmp/slices/slice_{index}/payload/slice_{index}.onnx"),
        relative_path: format!("slice_{index}/payload/slice_{index}.onnx"),
        shape: SliceShapeWrapper {
            tensor_shape: TensorShape {
                input: vec![vec![1, 3, 224, 224]],
                output: vec![vec![1, 64, 112, 112]],
            },
        },
        dependencies: Dependencies {
            input: vec![format!("input_{index}")],
            output: vec![format!("output_{index}")],
            filtered_inputs: vec![format!("input_{index}")],
        },
        tiling: None,
        channel_split: None,
        dim_split: None,
        compilation: Compilation::default(),
        slice_metadata: Some(format!("slice_{index}/metadata.msgpack")),
        slice_metadata_relative_path: Some(format!("slice_{index}/metadata.msgpack")),
    }
}

/// Build a synthetic `ModelMetadata` for a model cut into `num_slices` slices.
///
/// `slice_points` contains the `num_slices + 1` boundaries `0..=num_slices`.
fn make_model_metadata(num_slices: usize) -> ModelMetadata {
    let slices: Vec<SliceMetadata> = (0..num_slices).map(make_slice_metadata).collect();
    let slice_points: Vec<usize> = (0..=num_slices).collect();
    ModelMetadata {
        original_model: "/tmp/model.onnx".into(),
        model_type: "ONNX".into(),
        input_shape: vec![vec![1, 3, 224, 224]],
        output_shapes: vec![vec![1, 1000]],
        output_names: vec!["output".into()],
        slice_points,
        slices,
        dsperse_version: Some("0.0.0".into()),
        dsperse_rev: Some("abc1234".into()),
        jstprove_version: Some("0.1.0".into()),
        jstprove_rev: Some("def5678".into()),
        traced_shapes: None,
        traced_types: None,
        original_model_path: None,
        folded_constant_names: vec![],
    }
}

/// Build a synthetic `RunMetadata` for a linear chain of `num_slices` slices.
///
/// Every slice records a successful `JstproveGenWitness` execution (with one
/// tile) and a successful `JstproveProve` proof; no verification results are
/// recorded. Each execution node points at the next slice, and the last node
/// has no successor.
fn make_run_metadata(num_slices: usize) -> RunMetadata {
    let mut slices = HashMap::new();
    let mut nodes = HashMap::new();
    let mut execution_results = Vec::new();

    for i in 0..num_slices {
        let slice_id = format!("slice_{i}");

        slices.insert(
            slice_id.clone(),
            RunSliceMetadata {
                path: format!("slice_{i}/payload/slice_{i}.onnx"),
                input_shape: vec![vec![1, 3, 224, 224]],
                output_shape: vec![vec![1, 64, 112, 112]],
                dependencies: Dependencies {
                    input: vec![format!("input_{i}")],
                    output: vec![format!("output_{i}")],
                    filtered_inputs: vec![format!("input_{i}")],
                },
                tiling: None,
                channel_split: None,
                dim_split: None,
                backend: BackendKind::Jstprove,
                jstprove_circuit_path: Some(format!("slice_{i}/jstprove/circuit.bin")),
                jstprove_settings_path: None,
            },
        );

        nodes.insert(
            slice_id.clone(),
            ExecutionNode {
                slice_id: slice_id.clone(),
                primary: Some("jstprove_gen_witness".into()),
                fallbacks: vec!["onnx_only".into()],
                use_circuit: true,
                // Link to the successor slice; the final slice ends the chain.
                next: if i + 1 < num_slices {
                    Some(format!("slice_{}", i + 1))
                } else {
                    None
                },
                circuit_path: Some(format!("slice_{i}/jstprove/circuit.bin")),
                onnx_path: Some(format!("slice_{i}/payload/slice_{i}.onnx")),
                backend: BackendKind::Jstprove,
            },
        );

        execution_results.push(ExecutionResultEntry {
            slice_id: slice_id.clone(),
            witness_execution: Some(ExecutionInfo {
                method: ExecutionMethod::JstproveGenWitness,
                success: true,
                error: None,
                witness_file: Some(format!("runs/run_0/{slice_id}/witness.bin")),
                tile_exec_infos: vec![TileResult {
                    tile_idx: 0,
                    success: true,
                    error: None,
                    method: Some(ExecutionMethod::JstproveGenWitness),
                    time_sec: 1.23,
                    proof_path: None,
                }],
            }),
            proof_execution: Some(SliceResult {
                slice_id: slice_id.clone(),
                success: true,
                method: Some(ExecutionMethod::JstproveProve),
                error: None,
                proof_path: Some(format!("runs/run_0/{slice_id}/proof.bin")),
                time_sec: 45.67,
                tiles: Vec::new(),
            }),
            verification_execution: None,
        });
    }

    RunMetadata {
        slices,
        execution_chain: ExecutionChain {
            head: Some("slice_0".into()),
            nodes,
            fallback_map: HashMap::new(),
            execution_results,
            jstprove_proved_slices: num_slices,
            jstprove_verified_slices: 0,
        },
        packaging_type: Some("dsperse".into()),
        source_path: Some("/tmp/model.onnx".into()),
        run_directory: Some("/tmp/runs/run_0".into()),
        model_path: Some("/tmp/model.onnx".into()),
    }
}

/// Benchmark JSON and MessagePack (named-field) serialization and
/// deserialization of `value` inside one criterion group.
///
/// The encoded size of each format is embedded in the group name, so reports
/// show the size trade-off next to the timings.
fn bench_roundtrip<T: Serialize + for<'de> Deserialize<'de>>(
    c: &mut Criterion,
    name: &str,
    value: &T,
) {
    // Encode once up front: these buffers feed the deserialization benches
    // and supply the byte counts shown in the group name.
    let json_bytes = serde_json::to_vec(value).unwrap();
    let msgpack_bytes = rmp_serde::to_vec_named(value).unwrap();
    let group_name = format!(
        "{name} (json={}, msgpack={})",
        json_bytes.len(),
        msgpack_bytes.len()
    );
    let mut group = c.benchmark_group(&group_name);
    group.bench_function("json_serialize", |b| {
        b.iter(|| serde_json::to_vec(black_box(value)).unwrap());
    });
    group.bench_function("msgpack_serialize", |b| {
        b.iter(|| rmp_serde::to_vec_named(black_box(value)).unwrap());
    });
    group.bench_function("json_deserialize", |b| {
        b.iter(|| serde_json::from_slice::<T>(black_box(&json_bytes)).unwrap());
    });
    group.bench_function("msgpack_deserialize", |b| {
        b.iter(|| rmp_serde::from_slice::<T>(black_box(&msgpack_bytes)).unwrap());
    });
    group.finish();
}

/// Criterion entry point: round-trip small (4-slice) and large (64-slice)
/// model and run metadata.
fn serialization_benchmarks(c: &mut Criterion) {
    let small_model = make_model_metadata(4);
    let large_model = make_model_metadata(64);
    let small_run = make_run_metadata(4);
    let large_run = make_run_metadata(64);
    bench_roundtrip(c, "ModelMetadata_4slices", &small_model);
    bench_roundtrip(c, "ModelMetadata_64slices", &large_model);
    bench_roundtrip(c, "RunMetadata_4slices", &small_run);
    bench_roundtrip(c, "RunMetadata_64slices", &large_run);
}

criterion_group!(benches, serialization_benchmarks);
criterion_main!(benches);
================================================ FILE: crates/dsperse/build.rs ================================================
/// Run `git` with `args` and return its trimmed stdout.
///
/// Returns `None` when git cannot be spawned (e.g. not installed) or exits
/// with a non-zero status (e.g. the source tree is not a git checkout).
fn git_stdout(args: &[&str]) -> Option<String> {
    std::process::Command::new("git")
        .args(args)
        .output()
        .ok()
        .filter(|o| o.status.success())
        .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
}

/// Build script: compiles the vendored ONNX protobuf definitions and exports
/// `DSPERSE_GIT_REV` (only when a git revision is available) and
/// `DSPERSE_DISPLAY_VERSION` to the crate at compile time.
fn main() {
    prost_build::Config::new()
        .compile_protos(&["proto/onnx.proto"], &["proto/"])
        .expect("Failed to compile ONNX proto");

    let git_rev = git_stdout(&["rev-parse", "--short", "HEAD"]);
    if let Some(ref rev) = git_rev {
        println!("cargo:rustc-env=DSPERSE_GIT_REV={rev}");
    }

    // "0.0.0" is the placeholder version in Cargo.toml (the publish workflow
    // rewrites it from the release tag), so it marks an untagged dev build.
    let pkg_version = std::env::var("CARGO_PKG_VERSION").unwrap();
    let display_version = match (pkg_version.as_str(), &git_rev) {
        ("0.0.0", Some(rev)) => format!("dev-{rev}"),
        ("0.0.0", None) => "dev".to_string(),
        (v, Some(rev)) => format!("{v}+{rev}"),
        (v, None) => v.to_string(),
    };
    println!("cargo:rustc-env=DSPERSE_DISPLAY_VERSION={display_version}");

    // Re-run when HEAD moves. `--git-path` resolves the real location of the
    // file (e.g. inside a worktree's private git dir) instead of assuming `.git/`.
    if let Some(head_path) = git_stdout(&["rev-parse", "--git-path", "HEAD"]) {
        println!("cargo:rerun-if-changed={head_path}");
    }

    // When HEAD is a symbolic ref (on a branch), committing rewrites the branch
    // ref file rather than HEAD itself, so watch that file as well.
    // NOTE(review): if the branch ref exists only in packed-refs, this loose-ref
    // path may not exist on disk; confirm how Cargo treats that case.
    if let Some(ref_path) = git_stdout(&["symbolic-ref", "-q", "HEAD"])
        .and_then(|head_ref| git_stdout(&["rev-parse", "--git-path", head_ref.as_str()]))
    {
        println!("cargo:rerun-if-changed={ref_path}");
    }
}
================================================ FILE: crates/dsperse/proto/onnx.proto ================================================ // // WARNING: This file is automatically generated! Please edit onnx.in.proto. // // SPDX-License-Identifier: Apache-2.0 syntax = "proto3"; package onnx; // Overview // // ONNX is an open specification that is comprised of the following components: // // 1) A definition of an extensible computation graph model. // 2) Definitions of standard data types. // 3) Definitions of built-in operators. // // This document describes the syntax of models and their computation graphs, // as well as the standard data types. Together, they are referred to as the ONNX // Intermediate Representation, or 'IR' for short. // // The normative semantic specification of the ONNX IR is found in docs/IR.md. // Definitions of the built-in neural network operators may be found in docs/Operators.md. // Notes // // Protobuf compatibility // // To simplify framework compatibility, ONNX is defined using the subset of protobuf // that is compatible with both protobuf v2 and v3. This means that we do not use any // protobuf features that are only available in one of the two versions. // // Here are the most notable contortions we have to carry out to work around // these limitations: // // - No 'map' (added protobuf 3.0).
//     We instead represent mappings as lists
//     of key-value pairs, where order does not matter and duplicates
//     are not allowed.


// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
  // proto3 requires the first enum value to be zero.
  // We add this just to appease the compiler.
  _START_VERSION = 0;
  // The version field is always serialized and we will use it to store the
  // version that the graph is generated from. This helps us set up version
  // control.
  // For the IR, we are using simple numbers starting with 0x00000001,
  // which was the version we published on Oct 10, 2017.
  IR_VERSION_2017_10_10 = 0x0000000000000001;

  // IR_VERSION 2 published on Oct 30, 2017
  // - Added type discriminator to AttributeProto to support proto3 users
  IR_VERSION_2017_10_30 = 0x0000000000000002;

  // IR VERSION 3 published on Nov 3, 2017
  // - For operator versioning:
  //    - Added new message OperatorSetIdProto
  //    - Added opset_import in ModelProto
  // - For vendor extensions, added domain in NodeProto
  IR_VERSION_2017_11_3 = 0x0000000000000003;

  // IR VERSION 4 published on Jan 22, 2019
  // - Relax constraint that initializers should be a subset of graph inputs
  // - Add type BFLOAT16
  IR_VERSION_2019_1_22 = 0x0000000000000004;

  // IR VERSION 5 published on March 18, 2019
  // - Add message TensorAnnotation.
  // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
  IR_VERSION_2019_3_18 = 0x0000000000000005;

  // IR VERSION 6 published on Sep 19, 2019
  // - Add support for sparse tensor constants stored in model.
  //   - Add message SparseTensorProto
  //   - Add sparse initializers
  IR_VERSION_2019_9_19 = 0x0000000000000006;

  // IR VERSION 7 published on May 8, 2020
  // - Add support to allow function body graph to rely on multiple external operator sets.
  // - Add a list to promote inference graph's initializers to global and
  //   mutable variables. Global variables are visible in all graphs of the
  //   stored models.
  // - Add message TrainingInfoProto to store initialization
  //   method and training algorithm. The execution of TrainingInfoProto
  //   can modify the values of mutable variables.
  // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
  IR_VERSION_2020_5_8 = 0x0000000000000007;

  // IR VERSION 8 published on July 30, 2021
  // - Introduce TypeProto.SparseTensor
  // - Introduce TypeProto.Optional
  // - Added a list of FunctionProtos local to the model
  // - Deprecated since_version and operator status from FunctionProto
  IR_VERSION_2021_7_30 = 0x0000000000000008;

  // IR VERSION 9 published on May 5, 2023
  // - Added AttributeProto to FunctionProto so that default attribute values can be set.
  // - Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
  IR_VERSION_2023_5_5 = 0x0000000000000009;

  // IR VERSION 10 published on March 25, 2024
  // - Added UINT4, INT4.
  IR_VERSION_2024_3_25 = 0x000000000000000A;

  // IR VERSION 11 published on TBD
  // - Added FLOAT4E2M1, multi-device protobuf classes.
  IR_VERSION = 0x000000000000000B;
}

// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
  reserved 12, 16 to 19;
  reserved "v";

  // Note: this enum is structurally identical to the OpSchema::AttrType
  // enum defined in schema.h. If you rev one, you likely need to rev the other.
  enum AttributeType {
    UNDEFINED = 0;
    FLOAT = 1;
    INT = 2;
    STRING = 3;
    TENSOR = 4;
    GRAPH = 5;
    SPARSE_TENSOR = 11;
    TYPE_PROTO = 13;

    FLOATS = 6;
    INTS = 7;
    STRINGS = 8;
    TENSORS = 9;
    GRAPHS = 10;
    SPARSE_TENSORS = 12;
    TYPE_PROTOS = 14;
  }

  // The name field MUST be present for this version of the IR.
  string name = 1;           // namespace Attribute

  // If ref_attr_name is not empty, it is the name of an attribute in the parent function.
  // In this case, this AttributeProto does not contain data; it is a reference to an
  // attribute in the parent scope.
  // NOTE: This should ONLY be used in a function (sub-graph). It is invalid in the main graph.
  string ref_attr_name = 21;

  // A human-readable documentation for this attribute. Markdown is allowed.
  string doc_string = 13;

  // The type field MUST be present for this version of the IR.
  // For 0.0.1 versions of the IR, this field was not defined, and
  // implementations needed to use has_field heuristics to determine
  // which value field was in use. For IR_VERSION 0.0.2 or later, this
  // field MUST be set and match the f|i|s|t|... field in use. This
  // change was made to accommodate proto3 implementations.
  AttributeType type = 20;   // discriminator that indicates which field below is in use

  // Exactly ONE of the following fields must be present for this version of the IR
  float f = 2;               // float
  int64 i = 3;               // int
  bytes s = 4;               // UTF-8 string
  TensorProto t = 5;         // tensor value
  GraphProto g = 6;          // graph
  SparseTensorProto sparse_tensor = 22;  // sparse tensor value
  // Do not use the field below; it is deprecated (field number 12 is reserved above).
  // optional ValueProto v = 12;         // value - subsumes everything but graph
  TypeProto tp = 14;         // type proto

  repeated float floats = 7;          // list of floats
  repeated int64 ints = 8;            // list of ints
  repeated bytes strings = 9;         // list of UTF-8 strings
  repeated TensorProto tensors = 10;  // list of tensors
  repeated GraphProto graphs = 11;    // list of graphs
  repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
  repeated TypeProto type_protos = 15; // list of type protos
}

// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
  // This field MUST be present in this version of the IR.
  string name = 1;     // namespace Value
  // This field MUST be present in this version of the IR for
  // inputs and outputs of the top-level graph.
  TypeProto type = 2;
  // A human-readable documentation for this value. Markdown is allowed.
  string doc_string = 3;
  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 4;
}

// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
  repeated string input = 1;    // namespace Value
  repeated string output = 2;   // namespace Value

  // An optional identifier for this node in a graph.
  // This field MAY be absent in this version of the IR.
  string name = 3;     // namespace Node

  // The symbolic identifier of the Operator to execute.
  string op_type = 4;  // namespace Operator
  // The domain of the OperatorSet that specifies the operator named by op_type.
  string domain = 7;   // namespace Domain
  // Overload identifier, used only to map this to a model-local function.
  string overload = 8;

  // Additional named attributes.
  repeated AttributeProto attribute = 5;

  // A human-readable documentation for this node. Markdown is allowed.
  string doc_string = 6;

  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 9;

  // Configuration of multi-device annotations.
  repeated NodeDeviceConfigurationProto device_configurations = 10;
}

// IntIntListEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message IntIntListEntryProto {
  int64 key = 1;
  repeated int64 value = 2;
};

// Multi-device configuration proto for NodeProto.
message NodeDeviceConfigurationProto {
  // This field MUST be present for this version of the IR.
  // ID of the configuration. MUST match the name of a DeviceConfigurationProto.
  string configuration_id = 1;
  // Sharding spec for the node.
  repeated ShardingSpecProto sharding_spec = 2;
  // Pipeline stage of this node.
  int32 pipeline_stage = 3;
}

// ShardingSpecProto: This describes the sharding spec for a specific
// input or output tensor of a node.
message ShardingSpecProto {
  // This field MUST be present for this version of the IR.
  // Identifies the input or output of the node that is being sharded.
  // Required to match a name specified in the node's input or output list of ValueInfoProtos.
  // It is called `logical tensor` in subsequent descriptions.
  string tensor_name = 1;

  // The following is the list of devices across which the logical
  // tensor is sharded or replicated.
  repeated int64 device = 2;

  // Each element v in the `device` field above may represent either a
  // device or a set of devices (when we want the same shard/tensor
  // to be replicated across a subset of devices), as indicated by
  // the following optional map. If the map contains an entry for v,
  // then v represents a device group, and the map indicates the set
  // of devices in that group.
  repeated IntIntListEntryProto index_to_device_group_map = 3;

  // The following is the sharded-shape of the tensor, consisting of
  // the sharding-spec for each axis of the tensor.
  repeated ShardedDimProto sharded_dim = 4;
}

// ShardedDimProto: This describes the sharding spec for a single
// axis of a sharded tensor.
message ShardedDimProto {
  // This field MUST be present for this version of the IR.
  // The axis this sharding corresponds to. Must be in the range of
  // [-r, r - 1], where r is the rank of the tensor. Negative axis values mean
  // counting from the back.
  int64 axis = 1;

  // Describes how the tensor on the provided axis is sharded.
  // The common-case is described by a single instance of SimpleShardedDimProto.
  // Multiple instances can be used to handle cases where a sharded
  // tensor is reshaped, fusing multiple axes into one.
  repeated SimpleShardedDimProto simple_sharding = 2;
}

// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.
// N is allowed to be symbolic where M is required to be a constant.
message SimpleShardedDimProto {
  // Dimension value to be sharded.
  oneof dim {
    int64 dim_value = 1;
    string dim_param = 2;
  }

  // This field MUST be present for this version of the IR.
  // Number of shards to split dim into.
  int64 num_shards = 3;
}

// Training information
// TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data.
//
// The semantics of the initialization-step are that the initializers
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
// initialized as specified by the initializers in the graph, and then
// updated by the "initialization_binding" in every instance in
// ModelProto.training_info.
//
// The field "algorithm" defines a computation graph which represents a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto {
  // This field describes a graph to compute the initial tensors
  // upon starting the training process. Initialization graph has no input
  // and can have multiple outputs. Usually, trainable tensors in neural
  // networks are randomly initialized. To achieve that, for each tensor,
  // the user can put a random number operator such as RandomNormal or
  // RandomUniform in TrainingInfoProto.initialization.node and assign its
  // random output to the specific tensor using "initialization_binding".
  // This graph can also set the initializers in "algorithm" in the same
  // TrainingInfoProto; a use case is resetting the number of training
  // iterations to zero.
  //
  // By default, this field is an empty graph and its evaluation does not
  // produce any output. Thus, no initializer would be changed by default.
  GraphProto initialization = 1;

  // This field represents a training algorithm step. Given required inputs,
  // it computes outputs to update initializers in its own or inference graph's
  // initializer lists. In general, this field contains loss node, gradient node,
  // optimizer node, increment of iteration count.
  //
  // An execution of the training algorithm step is performed by executing the
  // graph obtained by combining the inference graph (namely "ModelProto.graph")
  // and the "algorithm" graph. That is, the actual
  // input/initializer/output/node/value_info/sparse_initializer list of
  // the training graph is the concatenation of
  // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
  // and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
  // in that order. This combined graph must satisfy the normal ONNX conditions.
  // Now, let's provide a visualization of graph combination for clarity.
  // Let the inference graph (i.e., "ModelProto.graph") be
  //    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
  // and the "algorithm" graph be
  //    tensor_d -> Add -> tensor_e
  // The combination process results in
  //    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
  //
  // Notice that an input of a node in the "algorithm" graph may reference the
  // output of a node in the inference graph (but not the other way round). Also, inference
  // nodes cannot reference inputs of "algorithm". With these restrictions, the inference graph
  // can always be run independently without training information.
  //
  // By default, this field is an empty graph and its evaluation does not
  // produce any output. Evaluating the default training step never
  // updates any initializers.
  GraphProto algorithm = 2;

  // This field specifies the bindings from the outputs of "initialization" to
  // some initializers in "ModelProto.graph.initializer" and
  // the "algorithm.initializer" in the same TrainingInfoProto.
  // See "update_binding" below for details.
  //
  // By default, this field is empty and no initializer would be changed
  // by the execution of "initialization".
  repeated StringStringEntryProto initialization_binding = 3;

  // Gradient-based training is usually an iterative procedure. In one gradient
  // descent iteration, we apply
  //
  // x = x - r * g
  //
  // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
  // gradient of "x" with respect to a chosen loss. To avoid adding assignments
  // into the training graph, we split the update equation into
  //
  // y = x - r * g
  // x = y
  //
  // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
  // tell that "y" should be assigned to "x", the field "update_binding" may
  // contain a key-value pair of strings, "x" (key of StringStringEntryProto)
  // and "y" (value of StringStringEntryProto).
  // For a neural network with multiple trainable (mutable) tensors, there can
  // be multiple key-value pairs in "update_binding".
  //
  // The initializers that appear as keys in "update_binding" are considered
  // mutable variables. This implies some behaviors
  // as described below.
  //
  //  1. We have only unique keys in all "update_binding"s so that two
  //     variables may not have the same name. This ensures that one
  //     variable is assigned at most once.
  //  2. The keys must appear in names of "ModelProto.graph.initializer" or
  //     "TrainingInfoProto.algorithm.initializer".
  //  3. The values must be output names of "algorithm" or "ModelProto.graph.output".
  //  4. Mutable variables are initialized to the value specified by the
  //     corresponding initializer, and then potentially updated by
  //     "initialization_binding"s and "update_binding"s in "TrainingInfoProto"s.
  //
  // This field usually contains names of trainable tensors
  // (in ModelProto.graph), optimizer states such as momentums in advanced
  // stochastic gradient methods (in TrainingInfoProto.algorithm),
  // and number of training iterations (in TrainingInfoProto.algorithm).
  //
  // By default, this field is empty and no initializer would be changed
  // by the execution of "algorithm".
  repeated StringStringEntryProto update_binding = 4;
}

// Models
//
// ModelProto is a top-level file/container format for bundling an ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
  // The version of the IR this model targets. See Version enum above.
  // This field MUST be present.
  int64 ir_version = 1;

  // The OperatorSets this model relies on.
  // All ModelProtos MUST have at least one entry that
  // specifies which version of the ONNX OperatorSet is
  // being imported.
  //
  // All nodes in the ModelProto's graph will bind against the operator
  // with the same-domain/same-op_type operator with the HIGHEST version
  // in the referenced operator sets.
  repeated OperatorSetIdProto opset_import = 8;

  // The name of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  string producer_name = 2;

  // The version of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  string producer_version = 3;

  // Domain name of the model.
  // We use reverse domain names as name space indicators. For example:
  // `com.facebook.fair` or `com.microsoft.cognitiveservices`
  //
  // Together with `model_version` and GraphProto.name, this forms the unique identity of
  // the graph.
  string domain = 4;

  // The version of the graph encoded.
  // NOTE(review): upstream points to the Version enum here, but that enum lists IR
  // versions (see ir_version above); this field appears to be the model's own
  // version number — confirm before relying on it.
  int64 model_version = 5;

  // A human-readable documentation for this model. Markdown is allowed.
  string doc_string = 6;

  // The parameterized graph that is evaluated to execute the model.
  GraphProto graph = 7;

  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 14;

  // Training-specific information. Sequentially executing all stored
  // `TrainingInfoProto.algorithm`s and assigning their outputs following
  // the corresponding `TrainingInfoProto.update_binding`s is one training
  // iteration. Similarly, to initialize the model
  // (as if training hasn't happened), the user should sequentially execute
  // all stored `TrainingInfoProto.initialization`s and assign their outputs
  // using `TrainingInfoProto.initialization_binding`s.
  //
  // If this field is empty, the training behavior of the model is undefined.
  repeated TrainingInfoProto training_info = 20;

  // A list of function protos local to the model.
  //
  // The (domain, name, overload) tuple must be unique across the function protos in this list.
  // In case of any conflicts the behavior (whether the model local functions are given higher priority,
  // or standard operator sets are given higher priority or this is treated as error) is defined by
  // the runtimes.
  //
  // The operator sets imported by FunctionProto should be compatible with the ones
  // imported by ModelProto and other model local FunctionProtos.
  // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
  // or by 2 FunctionProtos then versions for the operator set may be different but,
  // the operator schema returned for op_type, domain, version combination
  // for both the versions should be same for every node in the function body.
  //
  // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
  // is not allowed.
  repeated FunctionProto functions = 25;

  // Describes different target configurations for a multi-device use case.
  // A model MAY describe multiple multi-device configurations for execution.
  repeated DeviceConfigurationProto configuration = 26;
};

// DeviceConfigurationProto describes a multi-device configuration for a model.
message DeviceConfigurationProto {
  // This field MUST be present for this version of the IR.
  // Name of the configuration.
  string name = 1;
  // This field MUST be present for this version of the IR.
  // Number of devices inside this configuration.
  int32 num_devices = 2;
  // Optional names of the devices. MUST be length of num_devices if provided.
  repeated string device = 3;
}

// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
  string key = 1;
  string value = 2;
};

message TensorAnnotation {
  string tensor_name = 1;
  // Key-value pairs to annotate the tensor specified by `tensor_name` above.
  // The keys used in the mapping below must be pre-defined in ONNX spec.
  // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
  // quantization parameter keys.
  repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}

// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
  // The nodes in the graph, sorted topologically.
  repeated NodeProto node = 1;

  // The name of the graph.
  string name = 2;   // namespace Graph

  // A list of named tensor values, used to specify constant inputs of the graph.
  // Each initializer (both TensorProto as well as SparseTensorProto) MUST have a name.
  // The name MUST be unique across both initializer and sparse_initializer,
  // but the name MAY also appear in the input list.
  repeated TensorProto initializer = 5;

  // Initializers (see above) stored in sparse format.
  repeated SparseTensorProto sparse_initializer = 15;

  // A human-readable documentation for this graph. Markdown is allowed.
  string doc_string = 10;

  // The inputs and outputs of the graph.
  repeated ValueInfoProto input = 11;
  repeated ValueInfoProto output = 12;

  // Information for the values in the graph. The ValueInfoProto.name's
  // must be distinct. It is optional for a value to appear in value_info list.
  repeated ValueInfoProto value_info = 13;

  // This field carries information to indicate the mapping among a tensor and its
  // quantization parameter tensors. For example:
  // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
  // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
  repeated TensorAnnotation quantization_annotation = 14;

  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 16;

  // Field numbers and names removed from earlier IR versions; reserved so they
  // are never reused with a different meaning.
  reserved 3, 4, 6 to 9;
  reserved "ir_version", "producer_version", "producer_tag", "domain";
}

// Tensors
//
// A serialized tensor value.
message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool

    // IEEE754 half-precision floating-point format (16 bits wide).
    // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
    FLOAT16 = 10;

    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;     // complex with float32 real and imaginary components
    COMPLEX128 = 15;    // complex with float64 real and imaginary components

    // Non-IEEE floating-point format based on IEEE754 single-precision
    // floating-point number truncated to 16 bits.
    // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
    BFLOAT16 = 16;

    // Non-IEEE floating-point format based on papers
    // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
    // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
    // Operators supporting FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
    // The computation usually happens inside a block quantize / dequantize
    // fused by the runtime.
    FLOAT8E4M3FN = 17;    // float 8, mostly used for coefficients, supports nan, not inf
    FLOAT8E4M3FNUZ = 18;  // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
    FLOAT8E5M2 = 19;      // follows IEEE 754, supports nan, inf, mostly used for gradients
    FLOAT8E5M2FNUZ = 20;  // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero

    // 4-bit integer data types
    UINT4 = 21;  // Unsigned integer in range [0, 15]
    INT4 = 22;   // Signed integer in range [-8, 7], using two's-complement representation

    // 4-bit floating point data types
    FLOAT4E2M1 = 23;

    // Future extensions go here.
  }

  // The shape of the tensor.
  repeated int64 dims = 1;

  // The data type of the tensor.
  // This field MUST have a valid TensorProto.DataType value
  int32 data_type = 2;

  // For very large tensors, we may want to store them in chunks, in which
  // case the following fields will specify the segment that is stored in
  // the current TensorProto.
  message Segment {
    int64 begin = 1;
    int64 end = 2;
  }
  Segment segment = 3;

  // Tensor content must be organized in row-major order.
  //
  // Depending on the data_type field, exactly one of the fields below with
  // name ending in _data is used to store the elements of the tensor.

  // For float and complex64 values
  // Complex64 tensors are encoded as a single array of floats,
  // with the real components appearing in odd numbered positions (counting from 1),
  // and the corresponding imaginary component appearing in the
  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
  // is encoded as [1.0, 2.0, 3.0, 4.0])
  // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
  repeated float float_data = 4 [packed = true];

  // For int32, uint8, int8, uint16, int16, uint4, int4, bool, (b)float16, float8, and float4:
  // - (b)float16 and float8 values MUST be converted bit-wise into an unsigned integer
  //   representation before being written to the buffer.
  // - Each pair of uint4, int4, and float4 values MUST be packed as two 4-bit elements into a single byte.
  //   The first element is stored in the 4 least significant bits (LSB),
  //   and the second element is stored in the 4 most significant bits (MSB).
  //
  // Consequently:
  // - For data types with a bit-width of 8 or greater, each `int32_data` stores one element.
  // - For 4-bit data types, each `int32_data` stores two elements.
  //
  // When this field is present, the data_type field MUST be
  // INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ, FLOAT4E2M1
  repeated int32 int32_data = 5 [packed = true];

  // For strings.
  // Each element of string_data is a UTF-8 encoded Unicode
  // string. No trailing null, no leading BOM. The protobuf "string"
  // scalar type is not used to match ML community conventions.
  // When this field is present, the data_type field MUST be STRING
  repeated bytes string_data = 6;

  // For int64.
  // When this field is present, the data_type field MUST be INT64
  repeated int64 int64_data = 7 [packed = true];

  // Optionally, a name for the tensor.
  string name = 8; // namespace Value

  // A human-readable documentation for this tensor. Markdown is allowed.
  string doc_string = 12;

  // Serializations can either use one of the fields above, or use this
  // raw bytes field. The only exception is the string case, where one is
  // required to store the content in the repeated bytes string_data field.
  //
  // When this raw_data field is used to store tensor value, elements MUST
  // be stored as fixed-width, little-endian order.
  // Floating-point data types MUST be stored in IEEE 754 format.
  // Complex64 elements must be written as two consecutive FLOAT values, real component first.
  // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
  // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
  // uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.
  //
  // Note: the advantage of specific field rather than the raw_data field is
  // that in some cases (e.g. int data), protobuf does a better packing via
  // variable length storage, and may lead to smaller binary footprint.
  // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
  bytes raw_data = 9;

  // Data can be stored inside the protobuf file using type-specific fields or raw_data.
  // Alternatively, raw bytes data can be stored in an external file, using the external_data field.
  // external_data stores key-value pairs describing data location. Recognized keys are:
  // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
  //                           protobuf model was stored
  // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
  //                         Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
  // - "length" (optional) - number of bytes containing data. Integer stored as string.
  // - "checksum" (optional) - SHA1 digest of the file specified under the 'location' key.
  repeated StringStringEntryProto external_data = 13;

  // Location of the data for this tensor. MUST be one of:
  // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
  // - EXTERNAL - data stored in an external location as described by external_data field.
  enum DataLocation {
    DEFAULT = 0;
    EXTERNAL = 1;
  }

  // If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
  DataLocation data_location = 14;

  // For double
  // Complex128 tensors are encoded as a single array of doubles,
  // with the real components appearing in odd numbered positions (counting from 1),
  // and the corresponding imaginary component appearing in the
  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
  // is encoded as [1.0, 2.0, 3.0, 4.0])
  // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
  repeated double double_data = 10 [packed = true];

  // For uint64 and uint32 values
  // When this field is present, the data_type field MUST be
  // UINT32 or UINT64
  repeated uint64 uint64_data = 11 [packed = true];

  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 16;
}

// A serialized sparse-tensor value
message SparseTensorProto {
  // The sequence of non-default values are encoded as a tensor of shape [NNZ].
  // The default-value is zero for numeric tensors, and empty-string for string tensors.
  // values must have a non-empty name present which serves as a name for SparseTensorProto
  // when used in sparse_initializer list.
  TensorProto values = 1;

  // The indices of the non-default values, which may be stored in one of two formats.
  // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
  // corresponding to the j-th index of the i-th value (in the values tensor).
  // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
  // must be the linearized-index of the i-th value (in the values tensor).
  // The linearized-index can be converted into an index tuple (k_1,...,k_rank)
  // using the shape provided below.
  // The indices must appear in ascending order without duplication.
  // In the first format, the ordering is lexicographic-ordering:
  // e.g., index-value [1,4] must appear before [2,1]
  TensorProto indices = 2;

  // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
  repeated int64 dims = 3;
}

// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto { message Dimension { oneof value { int64 dim_value = 1; string dim_param = 2; // namespace Shape }; // Standard denotation can optionally be used to denote tensor // dimensions with standard semantic descriptions to ensure // that operations are applied to the correct axis of a tensor. // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition // for pre-defined dimension denotations. string denotation = 3; }; repeated Dimension dim = 1; } // Types // // The standard ONNX data types. message TypeProto { message Tensor { // This field MUST NOT have the value of UNDEFINED // This field MUST have a valid TensorProto.DataType value // This field MUST be present for this version of the IR. int32 elem_type = 1; TensorShapeProto shape = 2; } // repeated T message Sequence { // The type and optional shape of each element of the sequence. // This field MUST be present for this version of the IR. TypeProto elem_type = 1; }; // map message Map { // This field MUST have a valid TensorProto.DataType value // This field MUST be present for this version of the IR. // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING int32 key_type = 1; // This field MUST be present for this version of the IR. TypeProto value_type = 2; }; // wrapper for Tensor, Sequence, or Map message Optional { // The type and optional shape of the element wrapped. // This field MUST be present for this version of the IR. // Possible values correspond to OptionalProto.DataType enum TypeProto elem_type = 1; }; message SparseTensor { // This field MUST NOT have the value of UNDEFINED // This field MUST have a valid TensorProto.DataType value // This field MUST be present for this version of the IR. int32 elem_type = 1; TensorShapeProto shape = 2; } oneof value { // The type of a tensor. 
Tensor tensor_type = 1; // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values // as input and output to graphs and nodes. These types are needed to naturally // support classical ML operators. DNN operators SHOULD restrict their input // and output types to tensors. // The type of a sequence. Sequence sequence_type = 4; // The type of a map. Map map_type = 5; // The type of an optional. Optional optional_type = 9; // Type of the sparse tensor SparseTensor sparse_tensor_type = 8; } // An optional denotation can be used to denote the whole // type with a standard semantic description as to what is // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition // for pre-defined type denotations. string denotation = 6; } // Operator Sets // // OperatorSets are uniquely identified by a (domain, opset_version) pair. message OperatorSetIdProto { // The domain of the operator set being identified. // The empty string ("") or absence of this field implies the operator // set that is defined as part of the ONNX specification. // This field MUST be present in this version of the IR when referring to any other operator set. string domain = 1; // The version of the operator set being identified. // This field MUST be present in this version of the IR. int64 version = 2; } // Operator/function status. enum OperatorStatus { EXPERIMENTAL = 0; STABLE = 1; } message FunctionProto { // The name of the function, similar to op_type in NodeProto. // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. string name = 1; // Deprecated since IR Version 8 // optional int64 since_version = 2; reserved 2; reserved "since_version"; // Deprecated since IR Version 8 // optional OperatorStatus status = 3; reserved 3; reserved "status"; // The inputs and outputs of the function. repeated string input = 4; repeated string output = 5; // The attribute parameters of the function. 
// It is for function parameters without default values. repeated string attribute = 6; // The attribute protos of the function. // It is for function attributes with default values. // A function attribute shall be represented either as // a string attribute or an AttributeProto, not both. repeated AttributeProto attribute_proto = 11; // The nodes in the function. repeated NodeProto node = 7; // A human-readable documentation for this function. Markdown is allowed. string doc_string = 8; // The OperatorSets this function body (graph) relies on. // // All nodes in the function body (graph) will bind against the operator // with the same-domain/same-op_type operator with the HIGHEST version // in the referenced operator sets. This means at most one version can be relied // for one domain. // // The operator sets imported by FunctionProto should be compatible with the ones // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto // and ModelProto then versions for the operator set may be different but, // the operator schema returned for op_type, domain, version combination // for both the versions should be same. repeated OperatorSetIdProto opset_import = 9; // The domain which this function belongs to. // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. string domain = 10; // The overload identifier of the function. // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. string overload = 13; // Information for the values in the function. The ValueInfoProto.name's // must be distinct and refer to names in the function (including inputs, // outputs, and intermediate values). It is optional for a value to appear // in value_info list. repeated ValueInfoProto value_info = 12; // Named metadata values; keys should be distinct. 
repeated StringStringEntryProto metadata_props = 14; } // For using protobuf-lite option optimize_for = LITE_RUNTIME; ================================================ FILE: crates/dsperse/src/backend/jstprove.rs ================================================ use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex}; pub use jstprove_circuits::api::ExtractedOutputType as ExtractedOutput; pub use jstprove_circuits::api::ProofConfigType as ProofConfig; pub use jstprove_circuits::api::StampedProofConfigType as StampedProofConfig; pub use jstprove_circuits::api::VerifiedOutputType as VerifiedOutput; use jstprove_circuits::api::{ self, ArchitectureType as Architecture, CircuitParamsType as CircuitParams, CompiledCircuitType as CompiledCircuit, WANDBType as WANDB, }; use jstprove_circuits::runner::schema::WitnessRequest; use crate::error::{DsperseError, Result}; use super::traits::ProofBackend; #[derive(Debug)] pub struct JstproveBackend { compress: bool, bundle_cache: Mutex>>, } impl Default for JstproveBackend { fn default() -> Self { Self { compress: true, bundle_cache: Mutex::new(HashMap::new()), } } } impl JstproveBackend { pub fn new() -> Self { Self::default() } pub fn with_compress(mut self, compress: bool) -> Self { self.compress = compress; self } pub fn compress(&self) -> bool { self.compress } pub fn load_bundle_cached(&self, path: &Path) -> Result> { let key = path.canonicalize().unwrap_or_else(|_| path.to_path_buf()); let mut cache = self .bundle_cache .lock() .map_err(|e| DsperseError::Backend(format!("bundle cache lock poisoned: {e}")))?; if let Some(bundle) = cache.get(&key) { return Ok(Arc::clone(bundle)); } let bundle = Arc::new(load_bundle(path)?); cache.insert(key, Arc::clone(&bundle)); Ok(bundle) } pub fn clear_cache(&self) { let mut cache = match self.bundle_cache.lock() { Ok(cache) => cache, Err(e) => { tracing::warn!("bundle cache lock poisoned on clear: {e}"); e.into_inner() } }; let count = cache.len(); 
cache.clear(); tracing::debug!(cleared = count, "bundle cache cleared"); } /// Evict cached bundles whose canonical path starts with the given /// prefix. Used by callers that want to drop a model's entries /// without clearing the entire cache. pub fn evict_cache_by_prefix(&self, prefix: &Path) { let mut cache = match self.bundle_cache.lock() { Ok(cache) => cache, Err(e) => { tracing::warn!("bundle cache lock poisoned on evict: {e}"); e.into_inner() } }; let before = cache.len(); cache.retain(|k, _| !k.starts_with(prefix)); let evicted = before - cache.len(); if evicted > 0 { tracing::info!( prefix = %prefix.display(), evicted, remaining = cache.len(), "evicted bundle cache entries" ); } } /// Resolve the proof config for a freshly loaded bundle. Errors if /// the bundle does not carry a stamped proof config or if the /// stamped version does not match the current spec, so callers can /// fail fast on legacy or incompatible bundles instead of running /// the wrong prover. fn resolve_proof_config(bundle: &CompiledCircuit) -> Result { let stamped = bundle .metadata .as_ref() .and_then(|m| m.proof_config) .ok_or_else(|| { DsperseError::Backend( "circuit bundle has no stamped proof_config; recompile with a stamping prover" .into(), ) })?; stamped .ensure_current() .map_err(|e| DsperseError::Backend(format!("incompatible bundle: {e}")))?; Ok(stamped.config) } /// Resolve the proof config without touching the circuit or /// witness-solver blobs. Reads only `manifest.msgpack`, which is /// kilobytes versus the tens of megabytes a full bundle load /// pulls in. Falls back to `resolve_proof_config` on a full /// bundle load if the manifest is missing the stamp so callers /// still get the same "no stamped proof_config" error path for /// legacy bundles rather than a confusing deserialization /// failure. 
fn resolve_proof_config_from_manifest(&self, circuit_path: &Path) -> Result { match jstprove_io::bundle::read_bundle_metadata::(circuit_path) { Ok((Some(params), _)) => { let stamped = params.proof_config.ok_or_else(|| { DsperseError::Backend( "circuit bundle has no stamped proof_config; recompile with a stamping prover" .into(), ) })?; stamped .ensure_current() .map_err(|e| DsperseError::Backend(format!("incompatible bundle: {e}")))?; Ok(stamped.config) } Ok((None, _)) => { let bundle = self.load_bundle_cached(circuit_path)?; Self::resolve_proof_config(&bundle) } Err(e) => { // Surface the manifest-read failure so operators // investigating a slow verify path or a legacy // bundle layout can tell the fast path missed // rather than silently eating a parse / IO error. tracing::debug!( path = %circuit_path.display(), error = %e, "manifest-only proof_config read failed; falling back to full bundle load" ); let bundle = self.load_bundle_cached(circuit_path)?; Self::resolve_proof_config(&bundle) } } } pub fn compile( &self, circuit_path: &Path, config: ProofConfig, params: CircuitParams, architecture: Architecture, wandb: WANDB, ) -> Result<()> { let circuit_path_str = circuit_path .to_str() .ok_or_else(|| DsperseError::Backend("non-UTF8 circuit path".into()))?; api::compile( circuit_path_str, config, params, architecture, wandb, self.compress, ) .map_err(|e| DsperseError::Backend(format!("compile: {e}")))?; let key = circuit_path .canonicalize() .unwrap_or_else(|_| circuit_path.to_path_buf()); self.bundle_cache .lock() .map_err(|e| DsperseError::Backend(format!("bundle cache lock poisoned: {e}")))? 
.remove(&key); Ok(()) } pub fn witness( &self, circuit_path: &Path, input_json: &[u8], output_json: &[u8], ) -> Result> { let bundle = self.load_bundle_cached(circuit_path)?; let config = Self::resolve_proof_config(&bundle)?; let req = WitnessRequest { circuit: bundle.circuit.clone(), witness_solver: bundle.witness_solver.clone(), inputs: input_json.to_vec(), outputs: output_json.to_vec(), metadata: bundle.metadata.clone(), }; let result = api::witness(config, &req, self.compress) .map_err(|e| DsperseError::Backend(format!("witness: {e}")))?; Ok(result.witness) } pub fn witness_f64( &self, circuit_path: &Path, activations: &[f64], initializers: &[(Vec, Vec)], ) -> Result> { let bundle = self.load_bundle_cached(circuit_path)?; let config = Self::resolve_proof_config(&bundle)?; let params = bundle.metadata.as_ref().ok_or_else(|| { DsperseError::Backend( "circuit bundle missing metadata (required for quantization)".into(), ) })?; let result = api::witness_f64( config, &bundle.circuit, &bundle.witness_solver, params, activations, initializers, self.compress, ) .map_err(|e| DsperseError::Backend(format!("witness_f64: {e}")))?; Ok(result.witness) } pub fn load_params(&self, circuit_path: &Path) -> Result> { let bundle = self.load_bundle_cached(circuit_path)?; Ok(bundle.metadata.clone()) } pub fn prove(&self, circuit_path: &Path, witness_bytes: &[u8]) -> Result> { let bundle = self.load_bundle_cached(circuit_path)?; let config = Self::resolve_proof_config(&bundle)?; api::prove(config, &bundle.circuit, witness_bytes, self.compress) .map_err(|e| DsperseError::Backend(format!("prove: {e}"))) } pub fn extract_outputs( &self, witness_bytes: &[u8], num_model_inputs: usize, ) -> Result> { Ok(self .extract_outputs_full(witness_bytes, num_model_inputs)? .outputs) } /// Full extracted output bundle: inputs, outputs, and the /// witness-stamped scale parameters. 
Holographic verifiers call /// this after `verify_holographic` because the holographic /// verify path does not reach through `verify_and_extract`, yet /// the validator still needs the declared inputs (to cross-check /// against what it sent) and the scale fields (to report the /// same `VerifiedOutput` shape the non-holographic path /// produces). Keeping `extract_outputs` as a thin wrapper /// preserves the existing `Vec` contract for callers that /// only want the outputs. pub fn extract_outputs_full( &self, witness_bytes: &[u8], num_model_inputs: usize, ) -> Result { if num_model_inputs == 0 { return Err(DsperseError::Backend( "extract_outputs: num_model_inputs must be > 0".into(), )); } api::extract_outputs(witness_bytes, num_model_inputs) .map_err(|e| DsperseError::Backend(format!("extract_outputs: {e}"))) } pub fn verify( &self, circuit_path: &Path, witness_bytes: &[u8], proof_bytes: &[u8], ) -> Result { let bundle = self.load_bundle_cached(circuit_path)?; let config = Self::resolve_proof_config(&bundle)?; api::verify(config, &bundle.circuit, witness_bytes, proof_bytes) .map_err(|e| DsperseError::Backend(format!("verify: {e}"))) } pub fn verify_and_extract( &self, circuit_path: &Path, witness_bytes: &[u8], proof_bytes: &[u8], num_inputs: usize, expected_inputs: Option<&[f64]>, ) -> Result { let bundle = self.load_bundle_cached(circuit_path)?; let config = Self::resolve_proof_config(&bundle)?; api::verify_and_extract( config, &bundle.circuit, witness_bytes, proof_bytes, num_inputs, expected_inputs, ) .map_err(|e| DsperseError::Backend(format!("verify_and_extract: {e}"))) } /// Run holographic GKR setup against the compiled circuit at /// `circuit_path` and persist the resulting verifying key as /// `vk.bin` inside the bundle directory. The bundle is read from /// the cache, so callers that just compiled the bundle through /// [`Self::compile`] pay only the holographic setup cost on top. 
/// /// `setup_holographic_vk` only succeeds when the bundle was /// compiled with `ProofConfig::GoldilocksExt4Whir`; the underlying /// jstprove API rejects every other config. /// /// The vk blob is written using the same compression mode as the /// rest of the bundle (`Self::compress`) so /// `jstprove_io::bundle::read_vk_only` can decode it via the /// shared auto-detecting reader. pub fn setup_holographic_vk(&self, circuit_path: &Path) -> Result<()> { let bundle = self.load_bundle_cached(circuit_path)?; let config = Self::resolve_proof_config(&bundle)?; let vk_bytes = api::setup_holographic_vk(config, &bundle.circuit) .map_err(|e| DsperseError::Backend(format!("setup_holographic_vk: {e}")))?; let vk_path = circuit_path.join("vk.bin"); let payload = if self.compress { jstprove_io::compress_bytes(&vk_bytes) .map_err(|e| DsperseError::Backend(format!("compress vk: {e}")))? } else { vk_bytes }; std::fs::write(&vk_path, &payload).map_err(|e| DsperseError::io(e, &vk_path))?; Ok(()) } /// Generate a holographic GKR proof for an existing bundle and /// witness. Like [`Self::setup_holographic_vk`] this requires the /// bundle to have been compiled with /// `ProofConfig::GoldilocksExt4Whir`. pub fn prove_holographic(&self, circuit_path: &Path, witness_bytes: &[u8]) -> Result> { let bundle = self.load_bundle_cached(circuit_path)?; let config = Self::resolve_proof_config(&bundle)?; api::prove_holographic(config, &bundle.circuit, witness_bytes) .map_err(|e| DsperseError::Backend(format!("prove_holographic: {e}"))) } /// Verify a holographic GKR proof against the bundle's vk.bin. /// The vk is read independently of the (much larger) circuit /// blob, mirroring the validator-side flow where the verifying /// party only ever ships the vk. pub fn verify_holographic(&self, circuit_path: &Path, proof_bytes: &[u8]) -> Result { // Verifiers only need the vk and the proof config — the // circuit and witness solver blobs are not used downstream. 
// Skip load_bundle_cached here so validators that only ever // hold vk.bin + manifest.msgpack (the intended light-weight // deployment shape) don't fail with a missing circuit.bin // and don't pay the tens-of-megabytes read cost. let config = self.resolve_proof_config_from_manifest(circuit_path)?; let vk_bytes = jstprove_io::bundle::read_vk_only(circuit_path) .map_err(|e| DsperseError::Backend(format!("read vk: {e}")))?; api::verify_holographic(config, &vk_bytes, proof_bytes) .map_err(|e| DsperseError::Backend(format!("verify_holographic: {e}"))) } } impl ProofBackend for JstproveBackend { fn prove(&self, circuit_path: &Path, witness_bytes: &[u8]) -> Result> { self.prove(circuit_path, witness_bytes) } fn verify( &self, circuit_path: &Path, witness_bytes: &[u8], proof_bytes: &[u8], ) -> Result { self.verify(circuit_path, witness_bytes, proof_bytes) } fn witness_f64( &self, circuit_path: &Path, activations: &[f64], initializers: &[(Vec, Vec)], ) -> Result> { self.witness_f64(circuit_path, activations, initializers) } } fn load_bundle(circuit_path: &Path) -> Result { let path_str = circuit_path .to_str() .ok_or_else(|| DsperseError::Backend("non-UTF8 circuit path".into()))?; api::read_circuit_bundle(path_str) .map_err(|e| DsperseError::Backend(format!("read circuit bundle: {e}"))) } pub struct WarmCircuit { bundle: Arc, pub params: CircuitParams, initializers: Vec<(Vec, Vec)>, compress: bool, config: ProofConfig, } impl WarmCircuit { pub fn load( circuit_path: &Path, initializers: Vec<(Vec, Vec)>, backend: &JstproveBackend, ) -> Result { let bundle = backend.load_bundle_cached(circuit_path)?; let config = JstproveBackend::resolve_proof_config(&bundle)?; let params = bundle .metadata .clone() .ok_or_else(|| DsperseError::Backend("circuit bundle missing metadata".into()))?; Ok(Self { bundle, params, initializers, compress: backend.compress(), config, }) } pub fn witness_f64(&self, activations: &[f64]) -> Result> { let result = api::witness_f64( self.config, 
&self.bundle.circuit, &self.bundle.witness_solver, &self.params, activations, &self.initializers, self.compress, ) .map_err(|e| DsperseError::Backend(format!("witness_f64: {e}")))?; Ok(result.witness) } } #[cfg(test)] mod tests { use super::*; #[test] fn bundle_cache_starts_empty() { let backend = JstproveBackend::default(); let cache = backend.bundle_cache.lock().unwrap(); assert!(cache.is_empty()); } #[test] fn backend_constructs_without_proof_config_state() { let backend = JstproveBackend::default(); assert!(backend.compress()); } #[test] fn clear_cache_on_empty_succeeds() { let backend = JstproveBackend::default(); backend.clear_cache(); let cache = backend.bundle_cache.lock().unwrap(); assert!(cache.is_empty()); } #[test] fn clear_cache_removes_entries() { let backend = JstproveBackend::default(); let dummy = Arc::new(CompiledCircuit { circuit: vec![1, 2, 3], witness_solver: vec![], metadata: None, version: None, }); backend .bundle_cache .lock() .unwrap() .insert(PathBuf::from("/tmp/test-circuit"), dummy); assert_eq!(backend.bundle_cache.lock().unwrap().len(), 1); backend.clear_cache(); assert!(backend.bundle_cache.lock().unwrap().is_empty()); } #[test] fn load_bundle_cached_returns_error_for_missing_path() { let backend = JstproveBackend::default(); let result = backend.load_bundle_cached(Path::new("/nonexistent/circuit/path")); assert!(result.is_err()); assert!(backend.bundle_cache.lock().unwrap().is_empty()); } #[test] fn resolve_proof_config_rejects_unstamped_bundle() { let bundle = CompiledCircuit { circuit: vec![], witness_solver: vec![], metadata: None, version: None, }; let err = JstproveBackend::resolve_proof_config(&bundle).unwrap_err(); match err { DsperseError::Backend(msg) => { assert!(msg.contains("no stamped proof_config"), "{msg}") } other => panic!("expected Backend error, got {other:?}"), } } } ================================================ FILE: crates/dsperse/src/backend/mod.rs ================================================ pub mod 
jstprove; pub mod onnx; pub mod traits; pub use traits::ProofBackend; ================================================ FILE: crates/dsperse/src/backend/onnx.rs ================================================ use std::collections::HashMap; use std::path::Path; use std::sync::Arc; use ndarray::IxDyn; use tract_onnx::prelude::*; use tract_onnx::tract_hir::infer::Factoid; use crate::error::{DsperseError, Result}; pub fn coerce_tdim_inputs(inputs: &TVec) -> TVec { inputs .iter() .map(|t| { if t.datum_type() == DatumType::TDim { // Safety: datum_type() == TDim verified by outer condition let view = unsafe { t.as_slice_unchecked::() }; let vals: Vec = view.iter().map(|d| d.to_i64().unwrap_or(0)).collect(); Tensor::from_shape(t.shape(), &vals) .map(|t| t.into_tvalue()) .unwrap_or_else(|_| t.clone()) } else { t.clone() } }) .collect() } pub type NamedOutputs = HashMap, Vec)>; fn load_onnx_model(onnx_path: &Path) -> Result { tract_onnx::onnx() .model_for_path(onnx_path) .map_err(|e| DsperseError::Onnx(format!("load {}: {e}", onnx_path.display()))) } fn resolve_concrete_shape(model: &InferenceModel, input_shape: &[usize]) -> Result> { let model_shape = model .input_fact(0) .ok() .and_then(|f| f.shape.as_concrete_finite().ok().flatten()) .map(|s| s.to_vec()); if input_shape.is_empty() { return model_shape.ok_or_else(|| { DsperseError::Onnx("symbolic input shape — provide explicit shape".into()) }); } if let Some(ref ms) = model_shape { let model_elems: usize = ms.iter().product(); let input_elems: usize = input_shape.iter().product(); if input_shape.len() == 1 && ms.len() > 1 && model_elems == input_elems { tracing::debug!( model_shape = ?ms, provided_shape = ?input_shape, "reshaping flat input to model-declared shape" ); return Ok(ms.clone()); } } Ok(input_shape.to_vec()) } fn resolve_input_datum_type(model: &InferenceModel, idx: usize) -> Result { let fact = model .input_fact(idx) .map_err(|e| DsperseError::Onnx(format!("input fact at index {idx}: {e}")))?; 
fact.datum_type.concretize().ok_or_else(|| { DsperseError::Onnx(format!( "input fact at index {idx} has no concrete datum type; the model must declare a concrete element type for this input" )) }) } fn optimize_to_runnable( model: InferenceModel, concrete_shape: &[usize], input_dt: DatumType, ) -> Result> { model .with_input_fact(0, InferenceFact::dt_shape(input_dt, concrete_shape)) .map_err(|e| DsperseError::Onnx(format!("set input shape: {e}")))? .into_optimized() .map_err(|e| DsperseError::Onnx(format!("optimize: {e:#}")))? .into_runnable() .map_err(|e| DsperseError::Onnx(format!("make runnable: {e:#}"))) } pub fn run_inference_with_coercion( onnx_path: &Path, input_data: &[f64], input_shape: &[usize], ) -> Result { let model = load_onnx_model(onnx_path)?; let concrete_shape = resolve_concrete_shape(&model, input_shape)?; let input_dt = resolve_input_datum_type(&model, 0)?; if let Ok(plan) = optimize_to_runnable(model, &concrete_shape, input_dt) { let input = build_input_tvalue(input_data, &concrete_shape, input_dt)?; let result = plan .run(tvec![input]) .map_err(|e| DsperseError::Onnx(format!("run: {e:#}")))?; return extract_all_outputs(&result); } tracing::warn!("standard optimization failed; using inference plan with TDim coercion"); let model2 = load_onnx_model(onnx_path)?; let with_shape = model2 .with_input_fact(0, InferenceFact::dt_shape(input_dt, &concrete_shape)) .map_err(|e| DsperseError::Onnx(format!("set input: {e}")))?; let plan = tract_onnx::tract_hir::infer::InferenceSimplePlan::new(std::sync::Arc::new(with_shape)) .map_err(|e| DsperseError::Onnx(format!("inference plan: {e}")))?; let mut state = tract_onnx::tract_core::plan::SimpleState::new(&plan) .map_err(|e| DsperseError::Onnx(format!("state: {e}")))?; let input = build_input_tvalue(input_data, &concrete_shape, input_dt)?; let result = state .run_plan_with_eval(tvec![input], |session, op_state, node, inputs| { let coerced = coerce_tdim_inputs(&inputs); let eval_result = if let Some(st) = 
op_state { st.eval(session, node.op.as_op(), coerced) } else { node.op.eval(coerced) }; match eval_result { Ok(o) => Ok::<_, TractError>(o), Err(e) => { let Some(first) = inputs.first() else { return Err(e); }; tracing::warn!(node = %node.name, error = %e, "eval failed, using fallback"); let dt = first.datum_type(); let fallback = Tensor::zero_dt(dt, &[1]) .map_err(|alloc_err| { TractError::msg(format!( "node {}: eval failed ({e}); fallback allocation for dtype {dt:?} failed: {alloc_err}", node.name )) })? .into_tvalue(); let n = node.outputs.len().max(1); Ok((0..n).map(|_| fallback.clone()).collect()) } } }) .map_err(|e| DsperseError::Onnx(format!("inference run: {e:#}")))?; extract_all_outputs(&result) } fn extract_all_outputs(result: &[TValue]) -> Result { let mut outputs = NamedOutputs::new(); for (i, tv) in result.iter().enumerate() { let label = format!("output_{i}"); let (data, shape) = tvalue_to_f64(tv, &label)?; outputs.insert(label, (data, shape)); } Ok(outputs) } fn load_runnable( onnx_path: &Path, input_shape: &[usize], ) -> Result<(Arc, Vec, DatumType)> { let model = load_onnx_model(onnx_path)?; let concrete_shape = resolve_concrete_shape(&model, input_shape)?; let input_dt = resolve_input_datum_type(&model, 0)?; let plan = optimize_to_runnable(model, &concrete_shape, input_dt)?; Ok((plan, concrete_shape, input_dt)) } const I64_SAFE_BOUND_F64: f64 = I64_SAFE_BOUND as f64; fn reject_non_finite(v: f64, idx: usize, type_name: &str) -> Result<()> { if !v.is_finite() { return Err(DsperseError::Onnx(format!( "input[{idx}] = {v}: non-finite values are not accepted for {type_name} inputs" ))); } Ok(()) } fn validate_integer_input( v: f64, idx: usize, type_name: &str, type_min: f64, type_max: f64, ) -> Result<()> { reject_non_finite(v, idx, type_name)?; if v.trunc() != v { return Err(DsperseError::Onnx(format!( "input[{idx}] = {v}: fractional component cannot be represented as {type_name}" ))); } if v.abs() > I64_SAFE_BOUND_F64 { return 
Err(DsperseError::Onnx(format!( "input[{idx}] = {v}: magnitude exceeds IEEE-754 safe integer bound {I64_SAFE_BOUND}" ))); } if v < type_min || v > type_max { return Err(DsperseError::Onnx(format!( "input[{idx}] = {v}: outside representable range [{type_min}, {type_max}] for {type_name}" ))); } Ok(()) } fn build_input_tvalue(input_data: &[f64], shape: &[usize], dt: DatumType) -> Result { let f32_max_f64: f64 = f32::MAX as f64; macro_rules! build_bounded_int { ($t:ty, $name:expr, $min:expr, $max:expr) => {{ let mut data: Vec<$t> = Vec::with_capacity(input_data.len()); for (i, &v) in input_data.iter().enumerate() { validate_integer_input(v, i, $name, $min as f64, $max as f64)?; data.push(v as $t); } tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data) .map(|a| a.into_tvalue()) .map_err(|e| DsperseError::Onnx(format!("input tensor: {e}"))) }}; } if dt == f32::datum_type() { let mut data: Vec = Vec::with_capacity(input_data.len()); for (i, &v) in input_data.iter().enumerate() { reject_non_finite(v, i, "f32")?; if v < -f32_max_f64 || v > f32_max_f64 { return Err(DsperseError::Onnx(format!( "input[{i}] = {v}: magnitude exceeds representable f32 range [-{f32_max_f64}, {f32_max_f64}]" ))); } data.push(v as f32); } tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data) .map(|a| a.into_tvalue()) .map_err(|e| DsperseError::Onnx(format!("input tensor: {e}"))) } else if dt == f64::datum_type() { let mut data: Vec = Vec::with_capacity(input_data.len()); for (i, &v) in input_data.iter().enumerate() { reject_non_finite(v, i, "f64")?; data.push(v); } tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data) .map(|a| a.into_tvalue()) .map_err(|e| DsperseError::Onnx(format!("input tensor: {e}"))) } else if dt == u8::datum_type() { build_bounded_int!(u8, "u8", u8::MIN, u8::MAX) } else if dt == i8::datum_type() { build_bounded_int!(i8, "i8", i8::MIN, i8::MAX) } else if dt == u16::datum_type() { build_bounded_int!(u16, "u16", u16::MIN, u16::MAX) } else if dt == i16::datum_type() 
{ build_bounded_int!(i16, "i16", i16::MIN, i16::MAX) } else if dt == u32::datum_type() { build_bounded_int!(u32, "u32", u32::MIN, u32::MAX) } else if dt == i32::datum_type() { build_bounded_int!(i32, "i32", i32::MIN, i32::MAX) } else if dt == u64::datum_type() { let mut data: Vec = Vec::with_capacity(input_data.len()); for (i, &v) in input_data.iter().enumerate() { validate_integer_input(v, i, "u64", 0.0, I64_SAFE_BOUND_F64)?; data.push(v as u64); } tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data) .map(|a| a.into_tvalue()) .map_err(|e| DsperseError::Onnx(format!("input tensor: {e}"))) } else if dt == i64::datum_type() { let mut data: Vec = Vec::with_capacity(input_data.len()); for (i, &v) in input_data.iter().enumerate() { validate_integer_input(v, i, "i64", -I64_SAFE_BOUND_F64, I64_SAFE_BOUND_F64)?; data.push(v as i64); } tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data) .map(|a| a.into_tvalue()) .map_err(|e| DsperseError::Onnx(format!("input tensor: {e}"))) } else if dt == bool::datum_type() { let mut data: Vec = Vec::with_capacity(input_data.len()); for (i, &v) in input_data.iter().enumerate() { reject_non_finite(v, i, "bool")?; if v != 0.0 && v != 1.0 { return Err(DsperseError::Onnx(format!( "input[{i}] = {v}: bool inputs must be exactly 0 or 1" ))); } data.push(v != 0.0); } tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data) .map(|a| a.into_tvalue()) .map_err(|e| DsperseError::Onnx(format!("input tensor: {e}"))) } else { Err(DsperseError::Onnx(format!( "unsupported input datum type {dt:?}" ))) } } fn run_single( plan: &Arc, input_data: &[f64], shape: &[usize], dt: DatumType, ) -> Result> { let tv = build_input_tvalue(input_data, shape, dt)?; plan.run(tvec!(tv)) .map_err(|e| DsperseError::Onnx(format!("inference: {e}"))) } pub struct WarmModel { plan: Arc, input_shape: Vec, input_dt: DatumType, } impl WarmModel { pub fn load(onnx_path: &Path, input_shape: &[usize]) -> Result { let (plan, input_shape, input_dt) = load_runnable(onnx_path, 
input_shape)?; Ok(Self { plan, input_shape, input_dt, }) } pub fn run(&self, input_data: &[f64]) -> Result<(Vec, Vec)> { let result = run_single(&self.plan, input_data, &self.input_shape, self.input_dt)?; extract_first_output(&result) } } pub fn run_inference( onnx_path: &Path, input_data: &[f64], input_shape: &[usize], ) -> Result<(Vec, Vec)> { let (plan, concrete_shape, input_dt) = load_runnable(onnx_path, input_shape)?; let result = run_single(&plan, input_data, &concrete_shape, input_dt)?; extract_first_output(&result) } pub fn run_inference_named( onnx_path: &Path, input_data: &[f64], input_shape: &[usize], ) -> Result { let model = load_onnx_model(onnx_path)?; let output_names = collect_output_names(&model); let concrete_shape = resolve_concrete_shape(&model, input_shape)?; let input_dt = resolve_input_datum_type(&model, 0)?; match optimize_to_runnable(model, &concrete_shape, input_dt) { Ok(plan) => { let result = run_single(&plan, input_data, &concrete_shape, input_dt)?; zip_named_outputs(&output_names, &result) } Err(_) => { let mut result = run_inference_with_coercion(onnx_path, input_data, &concrete_shape)?; let mut named = NamedOutputs::new(); for (i, name) in output_names.iter().enumerate() { let key = format!("output_{i}"); if let Some(val) = result.remove(&key) { named.insert(name.clone(), val); } } Ok(named) } } } pub fn run_inference_multi( onnx_path: &Path, inputs: &[(&str, Vec, Vec)], ) -> Result<(Vec, Vec)> { let (result, _) = run_multi_inner(onnx_path, inputs)?; extract_first_output(&result) } pub fn run_inference_multi_named( onnx_path: &Path, inputs: &[(&str, Vec, Vec)], ) -> Result { let (result, output_names) = run_multi_inner(onnx_path, inputs)?; zip_named_outputs(&output_names, &result) } fn run_multi_inner( onnx_path: &Path, inputs: &[(&str, Vec, Vec)], ) -> Result<(TVec, Vec)> { let mut model = load_onnx_model(onnx_path)?; let output_names = collect_output_names(&model); let mut input_by_name: HashMap<&str, usize> = 
HashMap::with_capacity(inputs.len()); for (idx, (name, _, _)) in inputs.iter().enumerate() { if input_by_name.insert(*name, idx).is_some() { return Err(DsperseError::Onnx(format!( "duplicate provided input name '{name}'" ))); } } let model_input_count = model.inputs.len(); let model_input_names: Vec<(usize, String)> = model .inputs .iter() .enumerate() .map(|(i, outlet)| (i, model.nodes[outlet.node].name.clone())) .collect(); let mut input_order: Vec> = vec![None; model_input_count]; let mut input_dts: Vec> = vec![None; model_input_count]; for (i, name) in &model_input_names { if let Some(&provided_idx) = input_by_name.get(name.as_str()) { let dt = resolve_input_datum_type(&model, *i)?; model = model .with_input_fact(*i, InferenceFact::dt_shape(dt, &inputs[provided_idx].2)) .map_err(|e| DsperseError::Onnx(format!("set input {i} ({name}) shape: {e}")))?; input_order[*i] = Some(provided_idx); input_dts[*i] = Some(dt); } } let unknown_inputs: Vec<&str> = input_by_name .keys() .copied() .filter(|name| !model_input_names.iter().any(|(_, n)| n == *name)) .collect(); if !unknown_inputs.is_empty() { return Err(DsperseError::Onnx(format!( "provided inputs not present in model: {unknown_inputs:?}" ))); } let model = model .into_typed() .map_err(|e| { let unmatched: Vec<_> = input_order .iter() .enumerate() .filter(|(_, v)| v.is_none()) .map(|(i, _)| model_input_names[i].1.as_str()) .collect(); DsperseError::Onnx(format!("type analysis (unmatched: {unmatched:?}): {e}")) })? .into_optimized() .map_err(|e| DsperseError::Onnx(format!("optimize: {e:#}")))? 
.into_runnable()
        .map_err(|e| DsperseError::Onnx(format!("make runnable: {e:#}")))?;
    let mut input_tvs = TVec::new();
    for (model_idx, idx) in input_order.iter().enumerate() {
        let provided_idx = idx.ok_or_else(|| {
            let name = &model_input_names[model_idx].1;
            DsperseError::Onnx(format!(
                "model input {model_idx} ('{name}') not matched to provided tensors"
            ))
        })?;
        let dt = input_dts[model_idx].ok_or_else(|| {
            let name = &model_input_names[model_idx].1;
            DsperseError::Onnx(format!(
                "model input {model_idx} ('{name}') has no resolved datum type"
            ))
        })?;
        let (_, ref data, ref shape) = inputs[provided_idx];
        input_tvs.push(build_input_tvalue(data, shape, dt)?);
    }
    let result = model
        .run(input_tvs)
        .map_err(|e| DsperseError::Onnx(format!("inference: {e}")))?;
    Ok((result, output_names))
}

/// Return one name per model output, in output order.
///
/// Uses the outlet's label when it has one; otherwise synthesizes
/// `"{node name}_output_{slot}"` from the node that produces the outlet.
fn collect_output_names(model: &InferenceModel) -> Vec {
    model
        .outputs
        .iter()
        .map(|outlet| {
            model
                .outlet_label(*outlet)
                .map(String::from)
                .unwrap_or_else(|| {
                    format!("{}_output_{}", model.nodes[outlet.node].name, outlet.slot)
                })
        })
        .collect()
}

/// 2^53: up to this magnitude every integer is exactly representable as an
/// IEEE-754 `f64`. The checks below treat the bound as inclusive.
const I64_SAFE_BOUND: i64 = 9_007_199_254_740_992;

/// Convert an `i64` to `f64`, rejecting values whose magnitude exceeds
/// [`I64_SAFE_BOUND`] (where the conversion could silently round).
///
/// `label` prefixes the error message to identify the offending tensor.
///
/// # Errors
/// Returns `DsperseError::Onnx` when `|v| > I64_SAFE_BOUND`, including for
/// `i64::MIN`.
fn i64_to_f64_checked(v: i64, label: &str) -> Result {
    // `unsigned_abs` is defined for every i64. `v.abs()` overflows for
    // i64::MIN: it panics in debug builds and wraps to a negative value in
    // release builds, which would slip past the bound check.
    if v.unsigned_abs() > I64_SAFE_BOUND as u64 {
        return Err(DsperseError::Onnx(format!(
            "{label}: i64 value {v} exceeds IEEE-754 safe integer bound"
        )));
    }
    Ok(v as f64)
}

/// Convert a `u64` to `f64`, rejecting values above [`I64_SAFE_BOUND`].
///
/// # Errors
/// Returns `DsperseError::Onnx` (prefixed with `label`) when
/// `v > I64_SAFE_BOUND`.
fn u64_to_f64_checked(v: u64, label: &str) -> Result {
    if v > I64_SAFE_BOUND as u64 {
        return Err(DsperseError::Onnx(format!(
            "{label}: u64 value {v} exceeds IEEE-754 safe integer bound"
        )));
    }
    Ok(v as f64)
}

/// Decode a tract `TValue` into a flattened `f64` buffer plus its shape.
///
/// Supported datum types: f32, f64, i64, i32, u32, i16, u16, i8, u8, u64,
/// bool (decoded as 0.0 / 1.0) and TDim (cast to i64 first). The 64-bit
/// integer paths go through [`i64_to_f64_checked`] / [`u64_to_f64_checked`]
/// so values an `f64` cannot hold exactly are rejected rather than rounded.
///
/// # Errors
/// Returns `DsperseError::Onnx` (prefixed with `label`) when the array view
/// or TDim cast fails, a 64-bit value exceeds [`I64_SAFE_BOUND`], or the
/// datum type is not in the list above.
fn tvalue_to_f64(tv: &TValue, label: &str) -> Result<(Vec, Vec)> {
    let shape = tv.shape().to_vec();
    let dt = tv.datum_type();
    let data: Vec = if dt == f32::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter().map(|&v| f64::from(v)).collect()
    } else if dt == f64::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter().copied().collect()
    } else if dt == i64::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter()
            .map(|&v| i64_to_f64_checked(v, label))
            .collect::>>()?
    } else if dt == i32::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter().map(|&v| f64::from(v)).collect()
    } else if dt == u32::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter().map(|&v| f64::from(v)).collect()
    } else if dt == i16::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter().map(|&v| f64::from(v)).collect()
    } else if dt == u16::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter().map(|&v| f64::from(v)).collect()
    } else if dt == i8::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter().map(|&v| f64::from(v)).collect()
    } else if dt == u8::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter().map(|&v| f64::from(v)).collect()
    } else if dt == u64::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter()
            .map(|&v| u64_to_f64_checked(v, label))
            .collect::>>()?
    } else if dt == bool::datum_type() {
        let arr = tv
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect()
    } else if dt.is_tdim() {
        let casted = tv
            .cast_to::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: TDim->i64 cast: {e}")))?;
        let arr = casted
            .to_plain_array_view::()
            .map_err(|e| DsperseError::Onnx(format!("{label}: {e}")))?;
        arr.iter()
            .map(|&v| i64_to_f64_checked(v, label))
            .collect::>>()?
    } else {
        return Err(DsperseError::Onnx(format!(
            "{label}: unsupported datum type {dt:?}"
        )));
    };
    Ok((data, shape))
}

/// Decode every output tensor and key it by output name.
///
/// `names[i]` names `result[i]`; outputs past the end of `names` are keyed
/// as `"output_{i}"`.
///
/// # Errors
/// Propagates decode errors from [`tvalue_to_f64`] and rejects two outputs
/// that resolve to the same name.
fn zip_named_outputs(names: &[String], result: &[TValue]) -> Result {
    let mut map = HashMap::new();
    for (i, tv) in result.iter().enumerate() {
        let (data, shape) = tvalue_to_f64(tv, &format!("output {i}"))?;
        let name = names
            .get(i)
            .cloned()
            .unwrap_or_else(|| format!("output_{i}"));
        if map.insert(name.clone(), (data, shape)).is_some() {
            return Err(DsperseError::Onnx(format!(
                "duplicate output name '{name}'"
            )));
        }
    }
    Ok(map)
}

/// Decode only the first output tensor.
///
/// # Errors
/// Returns `DsperseError::Onnx` if `result` is empty, or a decode error from
/// [`tvalue_to_f64`].
fn extract_first_output(result: &[TValue]) -> Result<(Vec, Vec)> {
    let output = result
        .first()
        .ok_or_else(|| DsperseError::Onnx("no output from model".into()))?;
    tvalue_to_f64(output, "output tensor")
}

#[cfg(test)]
mod tests {
    use super::*;

    // Op set handed to the slicer by the fixture-based tests below.
    const TEST_OPS: &[&str] = &["Conv", "Gemm", "MatMul"];

    // Regression: `i64::MIN.abs()` overflows, so the bound check must use the
    // unsigned magnitude and reject i64::MIN instead of panicking (debug) or
    // letting it through (release).
    #[test]
    fn i64_to_f64_checked_rejects_i64_min() {
        let err = i64_to_f64_checked(i64::MIN, "i64").unwrap_err();
        assert!(
            format!("{err:?}").contains("safe integer bound"),
            "expected safe-integer-bound error for i64::MIN"
        );
        assert!(i64_to_f64_checked(-I64_SAFE_BOUND - 1, "i64").is_err());
        assert!(i64_to_f64_checked(i64::MAX, "i64").is_err());
    }

    // Slices the `net` fixture and runs inference on the first slice with an
    // all-zero input of the slice's declared input shape.
    #[test]
    fn run_inference_on_sliced_model() {
        let models_dir = std::path::PathBuf::from(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../tests/models/net"
        ));
        let model_path = models_dir.join("model.onnx");
        assert!(
            model_path.exists(),
            "fixture missing: {}",
            model_path.display()
        );
        let tmp = tempfile::tempdir().unwrap();
        let meta =
            crate::slicer::slice_model(&model_path, Some(tmp.path()), None, TEST_OPS, None)
                .expect("slice_model failed");
        crate::slicer::materializer::ensure_all_slices_materialized(tmp.path(), &meta)
            .expect("materialization failed");
        assert!(!meta.slices.is_empty(), "model produced zero slices");
        let first_slice = &meta.slices[0];
        let onnx_path = tmp
            .path()
            .join(format!("slice_0/payload/{}", first_slice.filename));
        assert!(
            onnx_path.exists(),
            "sliced ONNX missing: {}",
            onnx_path.display()
        );
        let input_shape = &first_slice.shape.tensor_shape.input;
        assert!(
            !input_shape.is_empty() && !input_shape[0].is_empty(),
            "empty input shape"
        );
        // Clamp dims below 1 (presumably dynamic dims) to 1 so the input has
        // a concrete element count.
        let shape: Vec = input_shape[0].iter().map(|&d| d.max(1) as usize).collect();
        let elem_count: usize = shape.iter().product();
        let input_data = vec![0.0f64; elem_count];
        let result =
run_inference(&onnx_path, &input_data, &shape);
        assert!(result.is_ok());
        let (output_data, output_shape) = result.unwrap();
        assert!(!output_data.is_empty());
        assert!(!output_shape.is_empty());
    }

    // A model path that does not exist must make run_inference return Err.
    #[test]
    fn run_inference_nonexistent_model() {
        let result = run_inference(Path::new("/nonexistent/model.onnx"), &[1.0], &[1]);
        assert!(result.is_err());
    }

    // A model path that does not exist must make WarmModel::load return Err.
    #[test]
    fn warm_model_load_nonexistent() {
        let result = WarmModel::load(Path::new("/nonexistent/model.onnx"), &[1, 1, 28, 28]);
        assert!(result.is_err());
    }

    // Slices the `net` fixture, loads the first slice as a WarmModel, and
    // checks that two runs on the same all-zero input give identical results.
    #[test]
    fn warm_model_load_and_run_on_slice() {
        let models_dir = std::path::PathBuf::from(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../tests/models/net"
        ));
        let model_path = models_dir.join("model.onnx");
        assert!(
            model_path.exists(),
            "fixture missing: {}",
            model_path.display()
        );
        let tmp = tempfile::tempdir().unwrap();
        let meta =
            crate::slicer::slice_model(&model_path, Some(tmp.path()), None, TEST_OPS, None)
                .expect("slice_model failed");
        crate::slicer::materializer::ensure_all_slices_materialized(tmp.path(), &meta)
            .expect("materialization failed");
        assert!(!meta.slices.is_empty(), "model produced zero slices");
        let first_slice = &meta.slices[0];
        let onnx_path = tmp
            .path()
            .join(format!("slice_0/payload/{}", first_slice.filename));
        assert!(
            onnx_path.exists(),
            "sliced ONNX missing: {}",
            onnx_path.display()
        );
        let input_shape = &first_slice.shape.tensor_shape.input;
        assert!(
            !input_shape.is_empty() && !input_shape[0].is_empty(),
            "empty input shape"
        );
        // Clamp dims below 1 (presumably dynamic dims) to 1.
        let shape: Vec = input_shape[0].iter().map(|&d| d.max(1) as usize).collect();
        let elem_count: usize = shape.iter().product();
        let warm = WarmModel::load(&onnx_path, &shape).expect("WarmModel::load failed");
        let input = vec![0.0f64; elem_count];
        let (data1, shape1) = warm.run(&input).unwrap();
        let (data2, shape2) = warm.run(&input).unwrap();
        assert!(!data1.is_empty());
        assert_eq!(shape1, shape2);
        assert_eq!(data1, data2);
    }

    // No outputs decode to an empty map.
    #[test]
    fn zip_named_outputs_empty() {
        let result = zip_named_outputs(&[], &[]).unwrap();
        assert!(result.is_empty());
    }

    // extract_first_output on an empty result set is an error.
    #[test]
    fn extract_first_output_empty() {
        let result = extract_first_output(&[]);
        assert!(result.is_err());
    }

    // The built tensor carries the requested datum type (f32, u8, i64, bool)
    // and the requested shape, bool maps 0.0/1.0 to false/true, and
    // DatumType::String is rejected.
    #[test]
    fn build_input_tvalue_respects_declared_dtypes() {
        let shape = [2usize, 3];
        let values: Vec = (0..6).map(|v| v as f64).collect();
        let tv_f32 = build_input_tvalue(&values, &shape, f32::datum_type()).unwrap();
        assert_eq!(tv_f32.datum_type(), f32::datum_type());
        assert_eq!(tv_f32.shape(), &shape);
        let tv_u8 = build_input_tvalue(&values, &shape, u8::datum_type()).unwrap();
        assert_eq!(tv_u8.datum_type(), u8::datum_type());
        let tv_i64 = build_input_tvalue(&values, &shape, i64::datum_type()).unwrap();
        assert_eq!(tv_i64.datum_type(), i64::datum_type());
        let bool_vals = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        let tv_bool = build_input_tvalue(&bool_vals, &shape, bool::datum_type()).unwrap();
        assert_eq!(tv_bool.datum_type(), bool::datum_type());
        let view = tv_bool.to_plain_array_view::().unwrap();
        assert_eq!(
            view.iter().copied().collect::>(),
            vec![false, true, false, true, false, true]
        );
        let unsupported = build_input_tvalue(&values, &shape, DatumType::String);
        assert!(unsupported.is_err());
    }

    // NaN and +/-infinity are rejected with a "non-finite" error for every
    // dtype exercised here.
    #[test]
    fn build_input_tvalue_rejects_non_finite() {
        let shape = [3usize];
        for dt in [
            f32::datum_type(),
            f64::datum_type(),
            u8::datum_type(),
            i64::datum_type(),
            bool::datum_type(),
        ] {
            for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
                let err = build_input_tvalue(&[0.0, bad, 1.0], &shape, dt).unwrap_err();
                let msg = format!("{err:?}");
                assert!(
                    msg.contains("non-finite"),
                    "expected non-finite error for dt={dt:?} val={bad}, got {msg}"
                );
            }
        }
    }

    // Integer dtypes reject values with a fractional part.
    #[test]
    fn build_input_tvalue_rejects_fractional_for_integer_dtypes() {
        let shape = [2usize];
        for dt in [
            u8::datum_type(),
            i8::datum_type(),
            u32::datum_type(),
            i32::datum_type(),
            i64::datum_type(),
            u64::datum_type(),
        ] {
            let err = build_input_tvalue(&[0.0, 1.5], &shape, dt).unwrap_err();
            let msg = format!("{err:?}");
            assert!(
                msg.contains("fractional"),
                "expected fractional error for dt={dt:?}, 
got {msg}"
            );
        }
    }

    // Integer dtypes reject whole values outside their range (e.g. 256 or -1
    // for u8) with an "outside" range error.
    #[test]
    fn build_input_tvalue_rejects_out_of_range_for_integer_dtypes() {
        let shape = [2usize];
        let cases: &[(DatumType, f64)] = &[
            (u8::datum_type(), 256.0),
            (u8::datum_type(), -1.0),
            (i8::datum_type(), 128.0),
            (i8::datum_type(), -129.0),
            (u16::datum_type(), -1.0),
            (i16::datum_type(), 32_768.0),
            (u32::datum_type(), -1.0),
        ];
        for (dt, bad) in cases.iter().copied() {
            let err = build_input_tvalue(&[0.0, bad], &shape, dt).unwrap_err();
            let msg = format!("{err:?}");
            assert!(
                msg.contains("outside"),
                "expected range error for dt={dt:?} val={bad}, got {msg}"
            );
        }
    }

    // +/-I64_SAFE_BOUND itself is accepted when building inputs and by the
    // decode-side checkers; one past the bound is rejected by the checkers.
    #[test]
    fn safe_integer_bound_is_inclusive_on_both_sides() {
        let shape = [3usize];
        let bound = I64_SAFE_BOUND as f64;
        build_input_tvalue(&[0.0, bound, -bound], &shape, i64::datum_type())
            .expect("i64 accepts +/- I64_SAFE_BOUND");
        build_input_tvalue(&[0.0, bound, 1.0], &shape, u64::datum_type())
            .expect("u64 accepts I64_SAFE_BOUND");
        i64_to_f64_checked(I64_SAFE_BOUND, "i64")
            .expect("i64_to_f64_checked accepts I64_SAFE_BOUND");
        i64_to_f64_checked(-I64_SAFE_BOUND, "i64")
            .expect("i64_to_f64_checked accepts -I64_SAFE_BOUND");
        u64_to_f64_checked(I64_SAFE_BOUND as u64, "u64")
            .expect("u64_to_f64_checked accepts I64_SAFE_BOUND");
        assert!(i64_to_f64_checked(I64_SAFE_BOUND + 1, "i64").is_err());
        assert!(u64_to_f64_checked(I64_SAFE_BOUND as u64 + 1, "u64").is_err());
    }

    // An i64 input well above the safe bound is rejected when building.
    #[test]
    fn build_input_tvalue_rejects_i64_above_safe_integer_bound() {
        let shape = [2usize];
        let unsafe_hi = (I64_SAFE_BOUND as f64) + 1024.0;
        let err = build_input_tvalue(&[0.0, unsafe_hi], &shape, i64::datum_type()).unwrap_err();
        let msg = format!("{err:?}");
        assert!(
            msg.contains("safe integer bound"),
            "expected safe-integer-bound error, got {msg}"
        );
    }

    // Finite f64 values beyond f32's range are rejected for f32 inputs, while
    // +/-f32::MAX is accepted and stays finite.
    #[test]
    fn build_input_tvalue_rejects_finite_f64_outside_f32_range() {
        let shape = [2usize];
        for bad in [1.0e40_f64, -1.0e40_f64] {
            assert!(bad.is_finite());
            let err = build_input_tvalue(&[0.0, bad], &shape, f32::datum_type()).unwrap_err();
            let msg = format!("{err:?}");
            assert!(
                msg.contains("representable f32 range"),
                "expected f32-range error for val={bad}, got {msg}"
            );
        }
        let ok = build_input_tvalue(
            &[0.0, f32::MAX as f64, -(f32::MAX as f64)],
            &[3],
            f32::datum_type(),
        )
        .unwrap();
        let view = ok.to_plain_array_view::().unwrap();
        assert!(view.iter().all(|v| v.is_finite()));
    }

    // bool inputs accept only exactly 0 or 1.
    #[test]
    fn build_input_tvalue_rejects_non_boolean_for_bool_dtype() {
        let shape = [2usize];
        let err = build_input_tvalue(&[0.0, 2.0], &shape, bool::datum_type()).unwrap_err();
        let msg = format!("{err:?}");
        assert!(
            msg.contains("bool inputs must be exactly 0 or 1"),
            "expected strict bool error, got {msg}"
        );
    }

    // Writes a single-node ONNX model to `path`: UINT8 input `x` of shape [3]
    // -> Cast(to = FLOAT) -> FLOAT output `y` of shape [3].
    fn write_uint8_cast_to_float_model(path: &Path) {
        use crate::slicer::onnx_proto;
        let input = onnx_proto::make_tensor_value_info("x", 2, &[3]); // 2 = UINT8
        let output = onnx_proto::make_tensor_value_info("y", 1, &[3]); // 1 = FLOAT
        let cast_to = onnx_proto::make_attribute_int("to", 1);
        let node = onnx_proto::make_node(
            "Cast",
            vec!["x".to_string()],
            vec!["y".to_string()],
            vec![cast_to],
        );
        let graph = onnx_proto::make_graph("g", vec![node], vec![input], vec![output], vec![]);
        // NOTE(review): 13 is presumably the opset version; confirm against make_model.
        let model = onnx_proto::make_model(graph, 13);
        onnx_proto::save_model(&model, path).unwrap();
    }

    // Writes a single-node ONNX model to `path`: UINT8 input `x` of shape [3]
    // -> Identity -> UINT8 output `y` of shape [3].
    fn write_uint8_identity_model(path: &Path) {
        use crate::slicer::onnx_proto;
        let input = onnx_proto::make_tensor_value_info("x", 2, &[3]); // UINT8
        let output = onnx_proto::make_tensor_value_info("y", 2, &[3]); // UINT8
        let node = onnx_proto::make_node(
            "Identity",
            vec!["x".to_string()],
            vec!["y".to_string()],
            vec![],
        );
        let graph = onnx_proto::make_graph("g", vec![node], vec![input], vec![output], vec![]);
        let model = onnx_proto::make_model(graph, 13);
        onnx_proto::save_model(&model, path).unwrap();
    }

    // WarmModel picks up the model's u8 input dtype and decodes the u8 output
    // back to the same f64 values.
    #[test]
    fn warm_model_decodes_uint8_output() {
        let tmp = tempfile::tempdir().unwrap();
        let onnx_path = tmp.path().join("u8_identity.onnx");
        write_uint8_identity_model(&onnx_path);
        let shape = [3usize];
        let warm = WarmModel::load(&onnx_path, &shape).expect("WarmModel::load");
        assert_eq!(warm.input_dt, u8::datum_type());
        let (data, out_shape) = warm.run(&[0.0, 128.0, 255.0]).unwrap();
        assert_eq!(out_shape, shape.to_vec());
        assert_eq!(data, vec![0.0, 128.0, 255.0]);
    }

    // tvalue_to_f64 decodes the extreme values of each narrower integer dtype
    // (plus a small u64) and rejects a u64 above I64_SAFE_BOUND.
    #[test]
    fn tvalue_to_f64_covers_added_integer_dtypes() {
        fn tv_of(values: &[T]) -> TValue {
            let arr =
                tract_ndarray::ArrayD::from_shape_vec(IxDyn(&[values.len()]), values.to_vec())
                    .unwrap();
            arr.into_tvalue()
        }
        let (d, s) = tvalue_to_f64(&tv_of::(&[0, 255]), "u8").unwrap();
        assert_eq!((d, s), (vec![0.0, 255.0], vec![2]));
        let (d, _) = tvalue_to_f64(&tv_of::(&[-128, 127]), "i8").unwrap();
        assert_eq!(d, vec![-128.0, 127.0]);
        let (d, _) = tvalue_to_f64(&tv_of::(&[0, 65_535]), "u16").unwrap();
        assert_eq!(d, vec![0.0, 65_535.0]);
        let (d, _) = tvalue_to_f64(&tv_of::(&[-32_768, 32_767]), "i16").unwrap();
        assert_eq!(d, vec![-32_768.0, 32_767.0]);
        let (d, _) = tvalue_to_f64(&tv_of::(&[0, u32::MAX]), "u32").unwrap();
        assert_eq!(d, vec![0.0, u32::MAX as f64]);
        let (d, _) = tvalue_to_f64(&tv_of::(&[0, 1_000_000]), "u64").unwrap();
        assert_eq!(d, vec![0.0, 1_000_000.0]);
        let unsafe_hi = (I64_SAFE_BOUND as u64) + 7;
        let err = tvalue_to_f64(&tv_of::(&[unsafe_hi]), "u64").unwrap_err();
        assert!(
            format!("{err:?}").contains("safe integer bound"),
            "expected u64 safe-bound error"
        );
    }

    // A u8-input Cast model runs end to end with the declared u8 input dtype,
    // and a value that does not fit in u8 is rejected.
    #[test]
    fn warm_model_runs_non_f32_input_through_planner() {
        let tmp = tempfile::tempdir().unwrap();
        let onnx_path = tmp.path().join("u8_cast.onnx");
        write_uint8_cast_to_float_model(&onnx_path);
        let shape = [3usize];
        let warm = WarmModel::load(&onnx_path, &shape).expect("WarmModel::load");
        assert_eq!(warm.input_dt, u8::datum_type());
        let (data, out_shape) = warm.run(&[0.0, 42.0, 255.0]).unwrap();
        assert_eq!(out_shape, shape.to_vec());
        assert_eq!(data, vec![0.0, 42.0, 255.0]);
        // A second call with a value that can't round-trip through u8 must error
        // from build_input_tvalue before the planner is invoked.
let err = warm.run(&[0.0, 256.0, 0.0]).unwrap_err();
        assert!(format!("{err:?}").contains("outside"));
    }

    // Runs the u8-input Cast model through run_inference_multi_named with one
    // named input and checks the single output's shape and values.
    #[test]
    fn run_inference_multi_honors_per_input_dtype() {
        let tmp = tempfile::tempdir().unwrap();
        let onnx_path = tmp.path().join("u8_cast.onnx");
        write_uint8_cast_to_float_model(&onnx_path);
        let inputs: Vec<(&str, Vec, Vec)> = vec![("x", vec![1.0, 2.0, 3.0], vec![3])];
        let out = run_inference_multi_named(&onnx_path, &inputs).unwrap();
        let (data, shape) = out.values().next().expect("at least one output");
        assert_eq!(shape, &vec![3]);
        assert_eq!(data, &vec![1.0, 2.0, 3.0]);
    }

    // resolve_input_datum_type reports the dtype declared on the model's
    // first input (UINT8 in this fixture).
    #[test]
    fn resolve_input_datum_type_reads_concrete_model_dtype() {
        let tmp = tempfile::tempdir().unwrap();
        let onnx_path = tmp.path().join("u8_cast.onnx");
        write_uint8_cast_to_float_model(&onnx_path);
        let model = load_onnx_model(&onnx_path).unwrap();
        let dt = resolve_input_datum_type(&model, 0).unwrap();
        assert_eq!(dt, u8::datum_type());
    }
}
================================================ FILE: crates/dsperse/src/backend/traits.rs ================================================
use std::path::Path;

use crate::error::Result;

/// A proof backend: proves and verifies against a compiled circuit on disk
/// and computes witnesses from `f64` activations. Implementors must be
/// `Send + Sync`.
pub trait ProofBackend: Send + Sync {
    /// Produce a proof for `witness_bytes` against the circuit stored at
    /// `circuit_path`.
    fn prove(&self, circuit_path: &Path, witness_bytes: &[u8]) -> Result>;

    /// Check `proof_bytes` for `witness_bytes` against the circuit stored at
    /// `circuit_path`.
    fn verify(&self, circuit_path: &Path, witness_bytes: &[u8], proof_bytes: &[u8]) -> Result;

    /// Compute a witness for the circuit at `circuit_path` from `f64`
    /// activations plus initializer tensors.
    // NOTE(review): initializers are presumably (values, shape) pairs;
    // confirm against the implementations.
    fn witness_f64(
        &self,
        circuit_path: &Path,
        activations: &[f64],
        initializers: &[(Vec, Vec)],
    ) -> Result>;
}
================================================ FILE: crates/dsperse/src/cli/mod.rs ================================================
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};

use clap::{Args, Parser, Subcommand};

use crate::backend::jstprove::{JstproveBackend, ProofConfig};
use crate::error::{DsperseError, Result};
use crate::pipeline::{self, RunConfig};
use jstprove_circuits::api::{ProofConfigError, ProofSystemType as ProofSystem};

/// Parse a `--proof-config` (alias `--curve`) value via its `FromStr` impl.
///
/// # Errors
/// Returns `DsperseError::Other` wrapping the `ProofConfigError`.
// NOTE(review): the message still names the legacy `--curve` alias even
// though the flag is now `--proof-config`; consider updating it.
fn parse_proof_config(value: &str) -> Result {
    value.parse().map_err(|e: ProofConfigError| {
        DsperseError::Other(format!("invalid --curve '{value}': {e}"))
    })
}

/// Version string shown by `--version`, read from the
/// `DSPERSE_DISPLAY_VERSION` environment variable at compile time.
pub const VERSION: &str = env!("DSPERSE_DISPLAY_VERSION");

// Top-level CLI: one subcommand plus a global `--log-level` (default "warn").
// Plain `//` comments are used on clap-derived items throughout this file
// because clap turns `///` doc comments into help text.
#[derive(Parser)]
#[command(name = "dsperse", about = "Distributed zkML Toolkit", version = VERSION)]
pub struct Cli {
    #[command(subcommand)]
    pub command: Commands,
    #[arg(long, default_value = "warn", global = true)]
    pub log_level: String,
}

// Subcommands; each variant carries its own argument struct and is routed
// by `dispatch`.
#[derive(Subcommand)]
pub enum Commands {
    Slice(SliceArgs),
    Combine(CombineArgs),
    Compile(CompileArgs),
    Run(RunArgs),
    Prove(ProveArgs),
    Verify(VerifyArgs),
    Package(PackageArgs),
    Publish(PublishArgs),
    #[command(name = "full-run")]
    FullRun(FullRunArgs),
    Analyze(AnalyzeArgs),
    #[command(name = "setup-holographic")]
    SetupHolographic(SetupHolographicArgs),
}

/// Route a parsed subcommand to its `cmd_*` handler.
pub fn dispatch(command: Commands) -> Result<()> {
    match command {
        Commands::Slice(args) => cmd_slice(args),
        Commands::Combine(args) => cmd_combine(args),
        Commands::Compile(args) => cmd_compile(args),
        Commands::Run(args) => cmd_run(args),
        Commands::Prove(args) => cmd_prove(args),
        Commands::Verify(args) => cmd_verify(args),
        Commands::Package(args) => cmd_package(args),
        Commands::Publish(args) => cmd_publish(args),
        Commands::FullRun(args) => cmd_full_run(args),
        Commands::Analyze(args) => cmd_analyze(args),
        Commands::SetupHolographic(args) => cmd_setup_holographic(args),
    }
}

// Arguments for `dsperse slice`.
#[derive(Args)]
pub struct SliceArgs {
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub output_dir: Option,
    // NOTE(review): with `default_value = "512"` clap always supplies a
    // value, so this Option appears never to be None from the command line.
    // Confirm whether a way to disable tiling is intended.
    #[arg(long, default_value = "512")]
    pub tile_size: Option,
    #[arg(
        long,
        default_value = "expander",
        help = "Proof system backend (expander or remainder)"
    )]
    pub proof_system: String,
    #[arg(
        long,
        help = "Comma-separated ONNX op names to compile via the proof backend (default: all supported)"
    )]
    pub circuit_ops: Option,
    #[arg(
        long,
        value_delimiter = ',',
        help = "Concrete input shape as comma-separated dims (e.g. 1,3,560,560). Overrides dynamic dimensions."
    )]
    pub input_shape: Option>,
}

// Arguments for `dsperse combine`.
#[derive(Args)]
pub struct CombineArgs {
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub slices_dir: Option,
}

// Arguments for `dsperse compile`.
#[derive(Args)]
pub struct CompileArgs {
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub slices_dir: Option,
    // Index spec of slices to compile; parsed by parse_index_spec.
    #[arg(long)]
    pub layers: Option,
    #[arg(long, default_value = "1")]
    pub parallel: NonZeroUsize,
    #[arg(
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        help = "Compile circuits with weights as inputs for shared circuit reuse (default: true)"
    )]
    pub weights_as_inputs: bool,
    #[arg(
        long,
        default_value = "expander",
        help = "Proof system backend (expander or remainder)"
    )]
    pub proof_system: String,
    #[arg(
        long,
        help = "Comma-separated ONNX op names to compile via the proof backend (default: all supported)"
    )]
    pub circuit_ops: Option,
    // Exposed as `--proof-config`; the field keeps its historical name.
    #[arg(
        long = "proof-config",
        visible_alias = "curve",
        default_value = "bn254_raw",
        help = "Proof config: bn254_raw, goldilocks_basefold, goldilocks_ext2_basefold, goldilocks_ext3_whir, goldilocks_ext4_whir. The --curve alias is retained for backward compatibility and will be removed in a future release."
    )]
    pub curve: String,
    #[arg(
        long,
        help = "Skip compilation of slices whose estimated constraint count exceeds this threshold"
    )]
    pub skip_compile_over_size: Option,
    #[arg(
        long,
        default_value_t = false,
        action = clap::ArgAction::Set,
        help = "Allow the command to exit 0 when individual slices fail to compile. Failed slices fall back to ONNX execution at run / prove time, producing a partial-coverage proof. Off by default so CI surfaces real compile regressions."
    )]
    pub allow_onnx_fallback: bool,
    #[arg(
        long,
        default_value_t = false,
        action = clap::ArgAction::Set,
        help = "After compiling each slice, run holographic GKR setup and persist the verifying key as vk.bin in the bundle directory. Requires --proof-config goldilocks_ext4_whir."
    )]
    pub holographic: bool,
}

// Arguments for `dsperse run`.
#[derive(Args)]
pub struct RunArgs {
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub input_file: PathBuf,
    #[arg(long)]
    pub run_dir: Option,
    #[arg(long)]
    pub slices_dir: Option,
    #[arg(long, default_value = "1")]
    pub parallel: NonZeroUsize,
    #[arg(long)]
    pub batch: bool,
    #[arg(
        long,
        help = "Path to consumer ONNX with fine-tuned weights to inject at inference time"
    )]
    pub weights: Option,
    #[arg(
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        help = "Run inference on combined monolithic ONNX instead of per-slice execution"
    )]
    pub combined: bool,
}

// Arguments for `dsperse prove`.
#[derive(Args)]
pub struct ProveArgs {
    #[arg(long)]
    pub run_dir: PathBuf,
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub slices_dir: Option,
    #[arg(long, default_value = "1")]
    pub parallel: NonZeroUsize,
}

// Arguments for `dsperse verify`.
#[derive(Args)]
pub struct VerifyArgs {
    #[arg(long)]
    pub run_dir: PathBuf,
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub slices_dir: Option,
    #[arg(long, default_value = "1")]
    pub parallel: NonZeroUsize,
}

// Arguments for `dsperse package`.
#[derive(Args)]
pub struct PackageArgs {
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub slices_dir: Option,
    #[arg(long)]
    pub output_dir: Option,
    #[arg(long)]
    pub author: Option,
    #[arg(long)]
    pub model_version: Option,
    #[arg(long)]
    pub model_name: Option,
    #[arg(long)]
    pub timeout: Option,
    // NOTE(review): the names listed in this help differ from the
    // --proof-config names used by compile/full-run; confirm which set the
    // packager expects.
    #[arg(
        long,
        help = "Finite field curve used as domain separator in content hashes (bn254, goldilocks, goldilocks_basefold, goldilocks_ext2, goldilocks_whir, goldilocks_whir_pq)"
    )]
    pub curve: Option,
}

// Arguments for `dsperse publish`. The auth token can also come from the
// REGISTRY_AUTH_TOKEN environment variable (its value is hidden in help).
#[derive(Args)]
pub struct PublishArgs {
    #[arg(long, help = "Package directory containing manifest.msgpack")]
    pub dir: PathBuf,
    #[arg(long, help = "Registry base URL")]
    pub url: String,
    #[arg(long, env = "REGISTRY_AUTH_TOKEN", hide_env_values = true)]
    pub auth_token: String,
    #[arg(long)]
    pub name: String,
    #[arg(long, default_value = "")]
    pub description: String,
    #[arg(long)]
    pub author: String,
    #[arg(long, default_value = "1.0.0")]
    pub version: String,
    #[arg(long, default_value = "JSTPROVE")]
    pub proof_system: String,
    #[arg(long, default_value = "3600")]
    pub timeout: u64,
    #[arg(long, default_value_t = false, help = "Activate model after upload")]
    pub activate: bool,
}

// Arguments for `dsperse full-run` (compile, run, prove, verify in sequence).
#[derive(Args)]
pub struct FullRunArgs {
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub input_file: Option,
    #[arg(long)]
    pub slices_dir: Option,
    #[arg(long)]
    pub layers: Option,
    #[arg(
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        help = "Compile circuits with weights as inputs for shared circuit reuse (default: true)"
    )]
    pub weights_as_inputs: bool,
    #[arg(long, default_value = "1")]
    pub parallel: NonZeroUsize,
    #[arg(long)]
    pub batch: bool,
    #[arg(
        long,
        help = "Path to consumer ONNX with fine-tuned weights to inject at inference time"
    )]
    pub weights: Option,
    #[arg(
        long,
        default_value = "expander",
        help = "Proof system backend (expander or remainder)"
    )]
    pub proof_system: String,
    #[arg(
        long,
        help = "Comma-separated ONNX op names to compile via the proof backend (default: all supported)"
    )]
    pub circuit_ops: Option,
    #[arg(
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        help = "Run inference on combined monolithic ONNX instead of per-slice execution"
    )]
    pub combined: bool,
    #[arg(
        long = "proof-config",
        visible_alias = "curve",
        default_value = "bn254_raw",
        help = "Proof config: bn254_raw, goldilocks_basefold, goldilocks_ext2_basefold, goldilocks_ext3_whir, goldilocks_ext4_whir. The --curve alias is retained for backward compatibility and will be removed in a future release."
    )]
    pub curve: String,
    #[arg(
        long,
        help = "Skip compilation of slices whose estimated constraint count exceeds this threshold"
    )]
    pub skip_compile_over_size: Option,
    #[arg(
        long,
        default_value_t = false,
        action = clap::ArgAction::Set,
        help = "Allow full-run to proceed when individual slices fail to compile. Failed slices fall back to ONNX execution, producing a partial-coverage proof. Off by default so CI surfaces real compile regressions."
)]
    pub allow_onnx_fallback: bool,
    #[arg(
        long,
        default_value_t = false,
        action = clap::ArgAction::Set,
        help = "After compiling each slice, run holographic GKR setup and persist the verifying key as vk.bin in the bundle directory. Requires --proof-config goldilocks_ext4_whir."
    )]
    pub holographic: bool,
}

// Arguments for `dsperse setup-holographic`.
#[derive(Args)]
pub struct SetupHolographicArgs {
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub slices_dir: Option,
    #[arg(long, default_value = "1")]
    pub parallel: NonZeroUsize,
    #[arg(
        long,
        default_value_t = false,
        action = clap::ArgAction::Set,
        help = "Re-run setup and overwrite vk.bin even when the bundle already has one"
    )]
    pub overwrite: bool,
}

/// Validated list of ONNX op names to route through the proof backend.
struct CircuitOps(Vec);

impl CircuitOps {
    /// Borrow the op names as string slices.
    fn as_refs(&self) -> Vec<&str> {
        self.0.iter().map(String::as_str).collect()
    }
}

/// Resolve which ONNX ops are compiled for the proof system named by
/// `proof_system_str`.
///
/// With no `circuit_ops`, every op the proof system supports is used.
/// Otherwise the spec is split on commas, entries are trimmed, empty entries
/// are dropped, and each remaining op must be supported.
///
/// # Errors
/// Returns `DsperseError::Other` if the proof system name does not parse,
/// the spec contains no ops, or an op is unsupported.
fn resolve_circuit_ops(proof_system_str: &str, circuit_ops: Option<&str>) -> Result {
    let ps: ProofSystem = proof_system_str
        .parse()
        .map_err(|e: jstprove_circuits::api::ProofSystemParseError| {
            DsperseError::Other(e.to_string())
        })?;
    let supported = ps.supported_ops();
    let ops = match circuit_ops {
        None => supported.iter().map(|s| (*s).to_string()).collect(),
        Some(spec) => {
            let requested: Vec = spec
                .split(',')
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
                .collect();
            if requested.is_empty() {
                return Err(DsperseError::Other(
                    "empty --circuit-ops; provide at least one op or omit the flag to use all supported ops".into(),
                ));
            }
            for op in &requested {
                if !supported.contains(&op.as_str()) {
                    return Err(DsperseError::Other(format!(
                        "op {op:?} is not supported by proof system {ps}. 
Supported: {supported:?}"
                    )));
                }
            }
            // Order follows the spec; duplicates are not removed.
            requested
        }
    };
    Ok(CircuitOps(ops))
}

/// Use `slices_dir` when given, otherwise `<model_dir>/slices`.
fn resolve_slices_dir(slices_dir: Option, model_dir: &Path) -> PathBuf {
    slices_dir.unwrap_or_else(|| model_dir.join("slices"))
}

/// `dsperse slice`: slice `<model_dir>/model.onnx` using the resolved
/// circuit ops.
///
/// # Errors
/// `DsperseError::Slicer` if `model.onnx` is missing; op-resolution and
/// slicer errors propagate.
pub fn cmd_slice(args: SliceArgs) -> Result<()> {
    let model_path = args.model_dir.join("model.onnx");
    if !model_path.exists() {
        return Err(DsperseError::Slicer(format!(
            "model.onnx not found in {}",
            args.model_dir.display()
        )));
    }
    let ops = resolve_circuit_ops(&args.proof_system, args.circuit_ops.as_deref())?;
    let metadata = crate::slicer::slice_model(
        &model_path,
        args.output_dir.as_deref(),
        args.tile_size,
        &ops.as_refs(),
        args.input_shape.as_deref(),
    )?;
    tracing::info!(slices = metadata.slices.len(), "slicing complete");
    Ok(())
}

/// `dsperse combine`: load slice metadata and write the combined ONNX to
/// disk, logging its path.
pub fn cmd_combine(args: CombineArgs) -> Result<()> {
    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);
    let meta = pipeline::runner::load_model_metadata(&slices_dir)?;
    let path = crate::slicer::combiner::materialize_combined_to_disk(&slices_dir, &meta)?;
    tracing::info!(path = %path.display(), "combined ONNX materialized");
    Ok(())
}

/// `dsperse compile`: compile slices into circuits.
///
/// Per-slice compile failures are tolerated only with
/// `--allow-onnx-fallback`; otherwise the report's failures become an error.
pub fn cmd_compile(args: CompileArgs) -> Result<()> {
    let proof_config = parse_proof_config(&args.curve)?;
    let backend = JstproveBackend::new();
    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);
    let layers = args
        .layers
        .as_ref()
        .map(|s| parse_index_spec(s))
        .transpose()?;
    let ops = resolve_circuit_ops(&args.proof_system, args.circuit_ops.as_deref())?;
    let report = pipeline::compile_slices(
        &slices_dir,
        &backend,
        proof_config,
        args.parallel.get(),
        args.weights_as_inputs,
        layers.as_deref(),
        &ops.as_refs(),
        args.skip_compile_over_size,
        args.holographic,
    )?;
    if args.allow_onnx_fallback {
        Ok(())
    } else {
        report.ok_if_no_failures().map(|_| ())
    }
}

/// `dsperse run`: run inference over the slices for `--input-file`.
///
/// Output goes to `--run-dir`, defaulting to `<model_dir>/run/run_<id>`
/// where `<id>` comes from `run_id()`.
pub fn cmd_run(args: RunArgs) -> Result<()> {
    if !args.input_file.is_file() {
        return Err(DsperseError::Other(format!(
            "input file not found: {}",
            args.input_file.display()
        )));
    }
    let backend = JstproveBackend::new();
    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);
    let run_dir = args
        .run_dir
        .unwrap_or_else(|| args.model_dir.join("run").join(format!("run_{}", run_id())));
    let config = RunConfig {
        parallel: args.parallel.get(),
        batch: args.batch,
        weights_onnx: args.weights,
        combined: args.combined,
    };
    pipeline::run_inference(&slices_dir, &args.input_file, &run_dir, &backend, &config)?;
    Ok(())
}

/// `dsperse prove`: prove an existing run directory.
pub fn cmd_prove(args: ProveArgs) -> Result<()> {
    let backend = JstproveBackend::new();
    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);
    pipeline::prove_run(&args.run_dir, &slices_dir, &backend, args.parallel.get())?;
    Ok(())
}

/// `dsperse verify`: verify the proofs of an existing run directory.
pub fn cmd_verify(args: VerifyArgs) -> Result<()> {
    let backend = JstproveBackend::new();
    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);
    pipeline::verify_run(&args.run_dir, &slices_dir, &backend, args.parallel.get())?;
    Ok(())
}

/// `dsperse package`: write a content-addressed package of the slices.
///
/// Output goes to `--output-dir`, defaulting to `<model_dir>/package`.
pub fn cmd_package(args: PackageArgs) -> Result<()> {
    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);
    let output_dir = args
        .output_dir
        .unwrap_or_else(|| args.model_dir.join("package"));
    let config = pipeline::packager::PackageConfig {
        output_dir,
        author: args.author,
        model_version: args.model_version,
        model_name: args.model_name,
        timeout: args.timeout,
        curve: args.curve,
    };
    let result = pipeline::packager::package_content_addressed(&slices_dir, &config)?;
    tracing::info!(
        components = result.component_count,
        weight_biases = result.wb_count,
        total_bytes = result.total_size,
        manifest = %result.manifest_path.display(),
        "content-addressed packaging complete"
    );
    Ok(())
}

/// `dsperse publish`: upload a package directory to the registry at `--url`.
/// A failed publish is logged before the error is returned.
pub fn cmd_publish(args: PublishArgs) -> Result<()> {
    let config = pipeline::publisher::PublishConfig {
        api_url: args.url,
        auth_token: args.auth_token,
        name: args.name,
        description: args.description,
        author: args.author,
        version: args.version,
        proof_system: args.proof_system,
        timeout: args.timeout,
        activate: args.activate,
    };
    let result = match pipeline::publisher::publish(&args.dir, &config)
{
        Ok(r) => r,
        Err(e) => {
            tracing::error!(error = %e, "publish failed");
            return Err(e);
        }
    };
    tracing::info!(
        model_id = %result.model_id,
        components_uploaded = result.components_uploaded,
        components_skipped = result.components_skipped,
        weights_uploaded = result.weights_uploaded,
        weights_skipped = result.weights_skipped,
        "publish complete"
    );
    Ok(())
}

/// `dsperse full-run`: compile, run inference, prove, and verify in one go.
///
/// Argument checks (proof config, input file present, `--weights` only with
/// `--weights-as-inputs`) run before compilation. The run directory is
/// `<model_dir>/run/run_<id>` where `<id>` comes from `run_id()`.
///
/// # Errors
/// Any stage's error is returned. Compile failures are tolerated only with
/// `--allow-onnx-fallback`.
pub fn cmd_full_run(args: FullRunArgs) -> Result<()> {
    let proof_config = parse_proof_config(&args.curve)?;
    let backend = JstproveBackend::new();
    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);
    let input_file = args
        .input_file
        .unwrap_or_else(|| args.model_dir.join(crate::utils::paths::INPUT_FILE));
    if !input_file.is_file() {
        return Err(DsperseError::Other(format!(
            "input file not found: {}",
            input_file.display()
        )));
    }
    if args.weights.is_some() && !args.weights_as_inputs {
        return Err(DsperseError::Other(
            "--weights requires --weights-as-inputs during compilation".into(),
        ));
    }
    let layers = args
        .layers
        .as_ref()
        .map(|s| parse_index_spec(s))
        .transpose()?;
    let ops = resolve_circuit_ops(&args.proof_system, args.circuit_ops.as_deref())?;
    tracing::info!("compiling slices");
    let report = pipeline::compile_slices(
        &slices_dir,
        &backend,
        proof_config,
        args.parallel.get(),
        args.weights_as_inputs,
        layers.as_deref(),
        &ops.as_refs(),
        args.skip_compile_over_size,
        args.holographic,
    )?;
    if !args.allow_onnx_fallback {
        report.ok_if_no_failures()?;
    }
    let run_dir = args.model_dir.join("run").join(format!("run_{}", run_id()));
    let config = RunConfig {
        parallel: args.parallel.get(),
        batch: args.batch,
        weights_onnx: args.weights,
        combined: args.combined,
    };
    tracing::info!("running inference");
    pipeline::run_inference(&slices_dir, &input_file, &run_dir, &backend, &config)?;
    tracing::info!("proving");
    pipeline::prove_run(&run_dir, &slices_dir, &backend, args.parallel.get())?;
    tracing::info!("verifying");
    pipeline::verify_run(&run_dir, &slices_dir, &backend, args.parallel.get())?;
    tracing::info!(run_dir = %run_dir.display(), "full run complete");
    Ok(())
}

/// `dsperse setup-holographic`: run holographic setup for each slice,
/// skipping bundles that already have a verifying key unless `--overwrite`
/// is set.
///
/// # Errors
/// Returns an error if any slice failed, after logging the summary counts.
pub fn cmd_setup_holographic(args: SetupHolographicArgs) -> Result<()> {
    let backend = JstproveBackend::new();
    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);
    let report = pipeline::setup_holographic_for_slices(
        &slices_dir,
        &backend,
        args.parallel.get(),
        args.overwrite,
    )?;
    tracing::info!(
        processed = report.processed,
        skipped = report.skipped_already_present,
        failed = report.failed.len(),
        "holographic setup complete"
    );
    report.ok_if_no_failures().map(|_| ())
}

// Arguments for `dsperse analyze`. (Plain `//` comments: clap turns `///`
// doc comments into help text.)
#[derive(Args)]
pub struct AnalyzeArgs {
    #[arg(long)]
    pub model_dir: PathBuf,
    #[arg(long)]
    pub slices_dir: Option,
    #[arg(
        long,
        default_value = "expander",
        help = "Proof system backend (expander or remainder)"
    )]
    pub proof_system: String,
    #[arg(
        long,
        help = "Comma-separated ONNX op names to compile via the proof backend"
    )]
    pub circuit_ops: Option,
    #[arg(
        long,
        help = "Skip slices whose estimated constraint count exceeds this"
    )]
    pub skip_compile_over_size: Option,
    #[arg(
        long = "proof-config",
        visible_alias = "curve",
        default_value = "bn254_raw",
        help = "Proof config for circuit signature computation"
    )]
    pub proof_config: String,
    #[arg(
        long,
        default_value_t = AnalyzeFormat::Table,
        value_enum,
        help = "Output format"
    )]
    pub format: AnalyzeFormat,
}

// Output format for `dsperse analyze`. (`//` rather than `///` on purpose:
// the ValueEnum derive would turn doc comments into value help text.)
#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum)]
pub enum AnalyzeFormat {
    Table,
    Json,
}

// Renders the format as "table" / "json".
impl std::fmt::Display for AnalyzeFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Table => f.write_str("table"),
            Self::Json => f.write_str("json"),
        }
    }
}

/// `dsperse analyze`: report, per slice, the chosen backend and reason,
/// estimated constraints, tiling/split flags, circuit signature and ops.
///
/// Prints pretty JSON with `--format json`; otherwise prints a fixed-width
/// table followed by a totals line.
///
/// # Errors
/// Propagates op-resolution, proof-config parsing, analysis, and JSON
/// serialization errors.
fn cmd_analyze(args: AnalyzeArgs) -> Result<()> {
    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);
    let ops = resolve_circuit_ops(&args.proof_system, args.circuit_ops.as_deref())?;
    // Validate --proof-config with the same parse_proof_config helper that
    // cmd_compile and cmd_full_run use, so a typo fails fast with an error
    // instead of silently producing signatures under an unintended curve.
    let proof_config = parse_proof_config(&args.proof_config)?;
    let proof_config_name = proof_config.to_string();
    let reports = pipeline::analyze_slices(
        &slices_dir,
        &ops.as_refs(),
        args.skip_compile_over_size,
        Some(proof_config_name.as_str()),
    )?;
    if matches!(args.format, AnalyzeFormat::Json) {
        println!(
            "{}",
            serde_json::to_string_pretty(&reports)
                .map_err(|e| DsperseError::Other(e.to_string()))?
        );
    } else {
        let hdr_ops = "OPS";
        println!(
            "{:<8} {:<10} {:<28} {:<14} {:<6} {:<6} {:<6} {:<12} {hdr_ops}",
            "SLICE", "BACKEND", "REASON", "EST.CONSTR", "TILED", "CHSPL", "DMSPL", "SIGNATURE"
        );
        println!("{}", "-".repeat(120));
        let mut jstprove_count = 0usize;
        let mut onnx_count = 0usize;
        let mut missing_count = 0usize;
        let mut total_constraints: u64 = 0;
        let mut unique_sigs: std::collections::HashSet = std::collections::HashSet::new();
        for r in &reports {
            let est = r
                .estimated_constraints
                .map(|c| format!("{c}"))
                .unwrap_or_default();
            // Show at most the first 12 characters of the signature. Cutting
            // on a char boundary (instead of byte index 12) cannot panic on
            // non-ASCII text; ASCII signatures print exactly as before.
            let sig = r
                .circuit_signature
                .as_deref()
                .map(|s| s.char_indices().nth(12).map_or(s, |(end, _)| &s[..end]))
                .unwrap_or("");
            println!(
                "{:<8} {:<10} {:<28} {:<14} {:<6} {:<6} {:<6} {:<12} {}",
                r.index,
                r.backend,
                r.reason,
                est,
                r.tiled,
                r.channel_split,
                r.dim_split,
                sig,
                r.ops,
            );
            match r.backend.as_str() {
                "jstprove" => jstprove_count += 1,
                "onnx" => onnx_count += 1,
                "missing" => missing_count += 1,
                other => {
                    tracing::warn!(
                        slice = r.index,
                        backend = other,
                        "analyze: unknown backend classification; not counted"
                    );
                }
            }
            if let Some(c) = r.estimated_constraints {
                // Saturate rather than overflow (a panic in debug builds)
                // when summing very large estimates.
                total_constraints = total_constraints.saturating_add(c);
            }
            if let Some(ref s) = r.circuit_signature {
                unique_sigs.insert(s.clone());
            }
        }
        println!("{}", "-".repeat(120));
        println!(
            "total: {} slices | jstprove: {} | onnx: {} | missing: {} | unique circuits: {} | total constraints: {}",
            reports.len(),
            jstprove_count,
            onnx_count,
            missing_count,
            unique_sigs.len(),
            total_constraints,
        );
    }
    Ok(())
}

fn parse_index_spec(spec: &str) -> Result> {
    let mut layers = Vec::new();
    for part in spec.split(',') {
        let
part = part.trim(); if part.is_empty() { continue; } if let Some((start, end)) = part.split_once('-') { let s: usize = start.trim().parse().map_err(|_| { DsperseError::Other(format!("invalid index spec range start: {start:?}")) })?; let e: usize = end.trim().parse().map_err(|_| { DsperseError::Other(format!("invalid index spec range end: {end:?}")) })?; if s > e { return Err(DsperseError::Other(format!( "invalid index spec range: start {s} > end {e}" ))); } layers.extend(s..=e); } else { let n: usize = part .parse() .map_err(|_| DsperseError::Other(format!("invalid index spec token: {part:?}")))?; layers.push(n); } } if layers.is_empty() { return Err(DsperseError::Other("empty index spec".into())); } Ok(layers) } fn run_id() -> String { let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default(); let uuid = uuid::Uuid::new_v4(); format!("{}_{}", now.as_secs(), uuid.as_simple()) } #[cfg(test)] mod tests { use super::*; use clap::Parser; #[test] fn parse_index_spec_single() { assert_eq!(parse_index_spec("3").unwrap(), vec![3]); } #[test] fn parse_index_spec_multiple() { assert_eq!(parse_index_spec("1,3,5").unwrap(), vec![1, 3, 5]); } #[test] fn parse_index_spec_range() { assert_eq!(parse_index_spec("2-5").unwrap(), vec![2, 3, 4, 5]); } #[test] fn parse_index_spec_mixed() { assert_eq!(parse_index_spec("0,2-4,7").unwrap(), vec![0, 2, 3, 4, 7]); } #[test] fn parse_index_spec_whitespace_tolerance() { assert_eq!(parse_index_spec(" 1 , 2 - 3 ").unwrap(), vec![1, 2, 3]); } #[test] fn parse_index_spec_empty_rejected() { assert!(parse_index_spec("").is_err()); } #[test] fn parse_index_spec_invalid_token() { assert!(parse_index_spec("abc").is_err()); } #[test] fn parse_index_spec_reversed_range() { assert!(parse_index_spec("5-2").is_err()); } #[test] fn parse_index_spec_trailing_comma() { assert_eq!(parse_index_spec("1,2,").unwrap(), vec![1, 2]); } #[test] fn run_id_format() { let id = run_id(); let parts: Vec<&str> = id.splitn(2, 
'_').collect();
        assert_eq!(parts.len(), 2);
        // The leading component must parse as a number (the seconds stamp).
        assert!(parts[0].parse::().is_ok());
        // `Uuid::as_simple` renders 32 hex digits with no hyphens.
        assert_eq!(parts[1].len(), 32);
    }

    #[test]
    fn run_id_unique() {
        let id1 = run_id();
        let id2 = run_id();
        assert_ne!(id1, id2);
    }

    #[test]
    fn cli_parse_slice_command() {
        let cli = Cli::parse_from(["dsperse", "slice", "--model-dir", "/tmp/model"]);
        assert!(matches!(cli.command, Commands::Slice(_)));
    }

    #[test]
    fn cli_parse_run_command() {
        let cli = Cli::parse_from([
            "dsperse",
            "run",
            "--model-dir",
            "/tmp/model",
            "--input-file",
            "/tmp/input.json",
        ]);
        assert!(matches!(cli.command, Commands::Run(_)));
    }

    #[test]
    fn cli_log_level_default() {
        let cli = Cli::parse_from(["dsperse", "slice", "--model-dir", "/tmp"]);
        assert_eq!(cli.log_level, "warn");
    }

    #[test]
    fn cli_log_level_override() {
        // --log-level is a global flag given before the subcommand.
        let cli = Cli::parse_from([
            "dsperse",
            "--log-level",
            "debug",
            "slice",
            "--model-dir",
            "/tmp",
        ]);
        assert_eq!(cli.log_level, "debug");
    }

    #[test]
    fn cli_compile_with_layers() {
        let cli = Cli::parse_from([
            "dsperse",
            "compile",
            "--model-dir",
            "/tmp",
            "--layers",
            "0,2-4",
        ]);
        // The layer spec is kept as a raw string at parse time.
        if let Commands::Compile(args) = cli.command {
            assert_eq!(args.layers.as_deref(), Some("0,2-4"));
        } else {
            panic!("expected Compile");
        }
    }

    #[test]
    fn cli_run_parallel() {
        let cli = Cli::parse_from([
            "dsperse",
            "run",
            "--model-dir",
            "/tmp",
            "--input-file",
            "/tmp/in.json",
            "--parallel",
            "4",
        ]);
        if let Commands::Run(args) = cli.command {
            assert_eq!(args.parallel.get(), 4);
        } else {
            panic!("expected Run");
        }
    }

    #[test]
    fn cli_slice_with_tile_size() {
        let cli = Cli::parse_from([
            "dsperse",
            "slice",
            "--model-dir",
            "/tmp",
            "--tile-size",
            "1024",
        ]);
        if let Commands::Slice(args) = cli.command {
            assert_eq!(args.tile_size, Some(1024));
        } else {
            panic!("expected Slice");
        }
    }

    #[test]
    fn cli_parse_combine_command() {
        let cli = Cli::parse_from(["dsperse", "combine", "--model-dir", "/tmp/model"]);
        assert!(matches!(cli.command, Commands::Combine(_)));
    }

    #[test]
    fn cli_parse_combine_with_slices_dir() {
        let cli = Cli::parse_from([
            "dsperse",
            "combine",
            "--model-dir",
            "/tmp/model",
            "--slices-dir",
            "/tmp/slices",
        ]);
        if let Commands::Combine(args) = cli.command {
            assert_eq!(
                args.slices_dir,
                Some(std::path::PathBuf::from("/tmp/slices"))
            );
        } else {
            panic!("expected Combine");
        }
    }

    #[test]
    fn cli_run_combined_default_true() {
        let cli = Cli::parse_from([
            "dsperse",
            "run",
            "--model-dir",
            "/tmp",
            "--input-file",
            "/tmp/in.json",
        ]);
        if let Commands::Run(args) = cli.command {
            assert!(args.combined);
        } else {
            panic!("expected Run");
        }
    }

    #[test]
    fn cli_run_combined_explicit_false() {
        // Boolean options such as --combined take an explicit value.
        let cli = Cli::parse_from([
            "dsperse",
            "run",
            "--model-dir",
            "/tmp",
            "--input-file",
            "/tmp/in.json",
            "--combined",
            "false",
        ]);
        if let Commands::Run(args) = cli.command {
            assert!(!args.combined);
        } else {
            panic!("expected Run");
        }
    }

    #[test]
    fn cli_compile_holographic_default_false() {
        let cli = Cli::parse_from(["dsperse", "compile", "--model-dir", "/tmp"]);
        if let Commands::Compile(args) = cli.command {
            assert!(!args.holographic);
        } else {
            panic!("expected Compile");
        }
    }

    #[test]
    fn cli_compile_holographic_explicit_true() {
        let cli = Cli::parse_from([
            "dsperse",
            "compile",
            "--model-dir",
            "/tmp",
            "--holographic",
            "true",
        ]);
        if let Commands::Compile(args) = cli.command {
            assert!(args.holographic);
        } else {
            panic!("expected Compile");
        }
    }

    #[test]
    fn cli_full_run_holographic_explicit_true() {
        let cli = Cli::parse_from([
            "dsperse",
            "full-run",
            "--model-dir",
            "/tmp",
            "--holographic",
            "true",
        ]);
        if let Commands::FullRun(args) = cli.command {
            assert!(args.holographic);
        } else {
            panic!("expected FullRun");
        }
    }

    #[test]
    fn cli_setup_holographic_command() {
        let cli = Cli::parse_from([
            "dsperse",
            "setup-holographic",
            "--model-dir",
            "/tmp",
            "--parallel",
            "4",
        ]);
        if let Commands::SetupHolographic(args) = cli.command {
            assert_eq!(args.parallel.get(), 4);
            assert!(!args.overwrite);
        } else {
            panic!("expected SetupHolographic");
        }
    }

    #[test]
    fn cli_setup_holographic_overwrite() {
        let cli = Cli::parse_from([
            "dsperse",
            "setup-holographic",
            "--model-dir",
            "/tmp",
            "--overwrite",
            "true",
        ]);
        if let Commands::SetupHolographic(args) = cli.command {
            assert!(args.overwrite);
        } else {
            panic!("expected SetupHolographic");
        }
    }

    #[test]
    fn cli_compile_wai_default_true() {
        let cli = Cli::parse_from(["dsperse", "compile", "--model-dir", "/tmp"]);
        if let Commands::Compile(args) = cli.command {
            assert!(args.weights_as_inputs);
        } else {
            panic!("expected Compile");
        }
    }

    #[test]
    fn cli_compile_wai_explicit_false() {
        let cli = Cli::parse_from([
            "dsperse",
            "compile",
            "--model-dir",
            "/tmp",
            "--weights-as-inputs",
            "false",
        ]);
        if let Commands::Compile(args) = cli.command {
            assert!(!args.weights_as_inputs);
        } else {
            panic!("expected Compile");
        }
    }

    #[test]
    fn resolve_circuit_ops_invalid_proof_system() {
        let result = resolve_circuit_ops("nonexistent", None);
        assert!(result.is_err());
    }

    #[test]
    fn resolve_circuit_ops_unsupported_op() {
        let result = resolve_circuit_ops("expander", Some("FakeOp"));
        assert!(result.is_err());
    }

    #[test]
    fn resolve_circuit_ops_empty_spec_rejected() {
        let result = resolve_circuit_ops("expander", Some(""));
        assert!(result.is_err());
    }

    #[test]
    fn resolve_circuit_ops_whitespace_only_spec_rejected() {
        let result = resolve_circuit_ops("expander", Some(" , , "));
        assert!(result.is_err());
    }

    #[test]
    fn resolve_circuit_ops_valid_specific_ops() {
        let supported = ProofSystem::Expander.supported_ops();
        assert!(!supported.is_empty());
        let first_op = supported[0];
        let ops = resolve_circuit_ops("expander", Some(first_op)).unwrap();
        assert_eq!(ops.as_refs(), vec![first_op]);
    }

    #[test]
    fn resolve_circuit_ops_none_returns_all() {
        // With no --circuit-ops, every op the proof system supports is
        // selected, in the proof system's own order.
        let ops = resolve_circuit_ops("expander", None).unwrap();
        let expected: Vec<&str> = ProofSystem::Expander.supported_ops().to_vec();
        assert_eq!(ops.as_refs(), expected);
    }

    #[test]
    fn resolve_slices_dir_custom_path() {
        let result = resolve_slices_dir(Some(PathBuf::from("/custom")), Path::new("/model"));
        assert_eq!(result, PathBuf::from("/custom"));
    }

    #[test]
    fn resolve_slices_dir_default_fallback() {
        let model_dir = Path::new("/model");
        let result = resolve_slices_dir(None, model_dir);
        assert_eq!(result, model_dir.join("slices"));
    }
}
================================================
FILE: crates/dsperse/src/converter.rs
================================================
use std::collections::{HashMap, HashSet};
use std::path::Path;

use jstprove_circuits::api::{
    self, ArchitectureType as Architecture, CircuitParamsType as CircuitParams, WANDBType as WANDB,
};

use crate::error::{DsperseError, Result};

/// Produce JSTprove circuit params, architecture, and W&B data for an ONNX
/// model, with no WAI exclusions and no traced shapes.
///
/// Shorthand for `prepare_jstprove_artifacts_filtered(onnx_path,
/// weights_as_inputs, &HashSet::new(), None)`.
pub fn prepare_jstprove_artifacts(
    onnx_path: &Path,
    weights_as_inputs: bool,
) -> Result<(CircuitParams, Architecture, WANDB)> {
    prepare_jstprove_artifacts_filtered(onnx_path, weights_as_inputs, &HashSet::new(), None)
}

/// Produce JSTprove circuit params, architecture, and W&B data for an ONNX
/// model.
///
/// * `traced_shapes` — when `Some`, every dimension is converted to
///   `usize` (negative dims become 1) and the map is passed to
///   `api::generate_metadata_with_shapes`; otherwise
///   `api::generate_metadata` is used.
/// * `weights_as_inputs` — when true, the circuit params are further filled
///   by `api::populate_wai_inputs` from the W&B data, with
///   `exclude_from_wai` passed through (presumably names to keep out of
///   WAI — confirm against jstprove_circuits).
///
/// Both metadata generation and WAI population failures are reported as
/// `DsperseError::Pipeline`.
pub fn prepare_jstprove_artifacts_filtered(
    onnx_path: &Path,
    weights_as_inputs: bool,
    exclude_from_wai: &HashSet,
    traced_shapes: Option<&HashMap>>,
) -> Result<(CircuitParams, Architecture, WANDB)> {
    let meta = match traced_shapes {
        Some(shapes) => {
            let converted: HashMap> = shapes
                .iter()
                .map(|(k, v)| {
                    (
                        k.clone(),
                        v.iter()
                            // NOTE(review): negative entries presumably mark
                            // dynamic/symbolic dims; they are pinned to 1.
                            .map(|&d| if d < 0 { 1 } else { d as usize })
                            .collect(),
                    )
                })
                .collect();
            api::generate_metadata_with_shapes(onnx_path, converted)
        }
        None => api::generate_metadata(onnx_path),
    }
    .map_err(|e| DsperseError::Pipeline(format!("ONNX metadata generation: {e:#}")))?;
    let mut params = meta.circuit_params;
    if weights_as_inputs {
        api::populate_wai_inputs(&mut params, &meta.wandb, exclude_from_wai)
            .map_err(|e| DsperseError::Pipeline(format!("WAI input population: {e}")))?;
    }
    Ok((params, meta.architecture, meta.wandb))
}

// Both tests only check that a missing model file is reported as an error.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prepare_jstprove_artifacts_nonexistent_model() {
        let result = prepare_jstprove_artifacts(Path::new("/nonexistent.onnx"), false);
        assert!(result.is_err());
    }

    #[test]
    fn prepare_jstprove_artifacts_with_weights_as_inputs() {
        let result = prepare_jstprove_artifacts(Path::new("/nonexistent.onnx"), true);
        assert!(result.is_err());
    }
}
================================================
FILE:
crates/dsperse/src/error.rs
================================================
use std::path::PathBuf;

/// Crate-wide `Result` alias over `std::result::Result`.
pub type Result = std::result::Result;

/// Error type shared by the dsperse library and CLI.
#[derive(Debug, thiserror::Error)]
pub enum DsperseError {
    /// Filesystem failure. The message names only the final path
    /// component (`file_name()`), not the full path.
    #[error("I/O error at {}: {source}", .path.file_name().and_then(|n| n.to_str()).unwrap_or(""))]
    Io {
        source: std::io::Error,
        path: PathBuf,
    },
    /// MessagePack serialization failure (converted via `From`).
    #[error("msgpack encode error: {0}")]
    MsgpackEncode(#[from] rmp_serde::encode::Error),
    /// MessagePack deserialization failure (converted via `From`).
    #[error("msgpack decode error: {0}")]
    MsgpackDecode(#[from] rmp_serde::decode::Error),
    #[error("ONNX error: {0}")]
    Onnx(String),
    #[error("backend error: {0}")]
    Backend(String),
    #[error("slicer error: {0}")]
    Slicer(String),
    #[error("archive error: {0}")]
    Archive(String),
    #[error("metadata error: {0}")]
    Metadata(String),
    #[error("pipeline error: {0}")]
    Pipeline(String),
    /// Catch-all; displays the message verbatim with no prefix.
    #[error("{0}")]
    Other(String),
}

impl DsperseError {
    /// Wrap an I/O error together with the path it occurred on.
    pub fn io(source: std::io::Error, path: impl Into) -> Self {
        Self::Io {
            source,
            path: path.into(),
        }
    }
}
================================================
FILE: crates/dsperse/src/lib.rs
================================================
pub mod backend;
pub mod cli;
pub mod converter;
pub mod error;
pub mod pipeline;
pub mod schema;
pub mod slicer;
pub mod utils;
pub mod version;

// Python bindings; compiled only with the `python` cargo feature.
#[cfg(feature = "python")]
mod python;
================================================
FILE: crates/dsperse/src/main.rs
================================================
use clap::Parser;
use tracing_subscriber::EnvFilter;

use dsperse::cli;

/// CLI entry point.
///
/// Parses arguments and installs a tracing subscriber. An env-provided
/// filter (`EnvFilter::try_from_default_env`) wins over `--log-level`.
/// Prints the version banner to stderr, then dispatches the subcommand.
/// Any error is logged and the process exits with status 1.
fn main() {
    let parsed = cli::Cli::parse();
    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&parsed.log_level)),
        )
        .init();
    eprintln!("dsperse {}", cli::VERSION);
    if let Err(e) = cli::dispatch(parsed.command) {
        tracing::error!("{e}");
        std::process::exit(1);
    }
}
================================================
FILE: crates/dsperse/src/pipeline/channel_split.rs
================================================
use std::collections::HashMap;
use
std::path::Path; use ndarray::{Array4, ArrayD, s}; use super::runner::{generate_wai_witness, resolve_circuit_path_optional, run_onnx_inference}; use super::tensor_store::TensorStore; use crate::backend::jstprove::JstproveBackend; use crate::error::{DsperseError, Result}; use crate::schema::execution::{ExecutionInfo, ExecutionMethod}; use crate::schema::tiling::{ChannelGroupInfo, ChannelSplitInfo}; use crate::slicer::onnx_proto::TensorProto; use crate::utils::io::read_msgpack; use crate::utils::paths::resolve_relative_path; pub(crate) fn reshape_channel_split_output( arr: ArrayD, target_shape: Option<&[i64]>, ) -> Result> { let Some(raw) = target_shape else { return Ok(arr); }; let target: Vec = raw .iter() .map(|&d| { usize::try_from(d).map_err(|_| { DsperseError::Pipeline(format!("negative dimension {d} in output_shape")) }) }) .collect::>>()?; if arr.shape() == target.as_slice() { return Ok(arr); } let actual_shape: Vec = arr.shape().to_vec(); let actual_elems: usize = actual_shape.iter().product(); let target_elems: usize = target.iter().product(); if actual_elems != target_elems { return Err(DsperseError::Pipeline(format!( "channel_split output element count mismatch: \ actual {actual_elems} (shape {actual_shape:?}) vs target {target_elems} (shape {target:?})" ))); } arr.into_shape_with_order(ndarray::IxDyn(&target)) .map_err(|e| { DsperseError::Pipeline(format!( "channel_split output reshape from {actual_shape:?} to {target:?}: {e}", )) }) } #[allow(clippy::too_many_arguments)] pub(crate) fn execute_channel_split( slices_dir: &Path, slice_run_dir: &Path, slice_id: &str, cs: &ChannelSplitInfo, target_shape: Option<&[i64]>, tensor_cache: &TensorStore, backend: &JstproveBackend, donor_init_map: Option<&HashMap>, ) -> Result { let input_arr = tensor_cache.get(&cs.input_name)?.clone(); let (input_4d, n, h) = if input_arr.ndim() == 4 { let s = input_arr.shape(); let n = s[0]; if n != 1 { return Err(DsperseError::Pipeline(format!( "channel split: batch size {n} not 
supported, expected 1" ))); } let h = s[2]; let arr = Array4::from_shape_vec((n, s[1], s[2], s[3]), input_arr.iter().copied().collect()) .map_err(|e| DsperseError::Pipeline(format!("channel split reshape: {e}")))?; (arr, n, h) } else { let n = 1usize; let input_flat: Vec = input_arr.iter().copied().collect(); let total_elements = input_flat.len(); let nc = n * cs.c_in; if nc > 0 && !total_elements.is_multiple_of(nc) { return Err(DsperseError::Pipeline(format!( "channel split reshape: total_elements {total_elements} not divisible by n*c_in ({nc})" ))); } let spatial = if cs.c_in > 0 && total_elements > 0 { total_elements / nc } else { cs.h * cs.w }; let h = cs.h.max(1); if spatial > 0 && h > 0 && spatial % h != 0 { return Err(DsperseError::Pipeline(format!( "channel split reshape: spatial {spatial} not divisible by h={h}" ))); } let w = if spatial > 0 && h > 0 { spatial / h } else { cs.w.max(1) }; let arr = Array4::from_shape_vec((n, cs.c_in, h, w), input_flat) .map_err(|e| DsperseError::Pipeline(format!("channel split reshape: {e}")))?; (arr, n, h) }; let mut accumulated: Option> = None; tracing::info!( slice = %slice_id, num_groups = cs.groups.len(), "channel split execution" ); let n_channels = input_4d.shape()[1]; for group in &cs.groups { if group.c_end > n_channels || group.c_start > group.c_end { return Err(DsperseError::Pipeline(format!( "channel group {} bounds [{}, {}) exceed channel dimension {}", group.group_idx, group.c_start, group.c_end, n_channels ))); } let group_input = input_4d .slice(s![.., group.c_start..group.c_end, .., ..]) .to_owned(); let group_input_dyn = group_input.into_dyn(); let group_dir = slice_run_dir.join(format!("group_{}", group.group_idx)); std::fs::create_dir_all(&group_dir).map_err(|e| DsperseError::io(e, &group_dir))?; let group_output = execute_channel_group( slices_dir, &group_dir, group, &group_input_dyn, backend, donor_init_map, )?; let group_4d = if group_output.ndim() == 4 { let s = group_output.shape(); 
Array4::from_shape_vec( (s[0], s[1], s[2], s[3]), group_output.iter().copied().collect(), ) .map_err(|e| DsperseError::Pipeline(format!("group output reshape: {e}")))? } else { let group_flat: Vec = group_output.iter().copied().collect(); let (out_h, out_w) = if cs.out_h > 0 && cs.out_w > 0 { (cs.out_h, cs.out_w) } else if cs.c_out > 0 { let out_spatial = group_flat.len() / (n * cs.c_out); if h > 0 && out_spatial > 0 && out_spatial.is_multiple_of(h) { (h, out_spatial / h) } else { return Err(DsperseError::Pipeline(format!( "cannot determine spatial layout for channel_split output: {} elements, c_out={}, set out_h/out_w in metadata", group_flat.len(), cs.c_out ))); } } else { return Err(DsperseError::Pipeline("channel split c_out is 0".into())); }; if n * cs.c_out * out_h * out_w != group_flat.len() { return Err(DsperseError::Pipeline(format!( "group output reshape mismatch: expected {} elements (n={}, c_out={}, h={}, w={}), got {}", n * cs.c_out * out_h * out_w, n, cs.c_out, out_h, out_w, group_flat.len() ))); } Array4::from_shape_vec((n, cs.c_out, out_h, out_w), group_flat) .map_err(|e| DsperseError::Pipeline(format!("group output reshape: {e}")))? 
}; accumulated = Some(match accumulated { Some(acc) => { if acc.shape() != group_4d.shape() { return Err(DsperseError::Pipeline(format!( "channel group {} shape {:?} does not match accumulator shape {:?}", group.group_idx, group_4d.shape(), acc.shape() ))); } acc + &group_4d } None => group_4d, }); } if let Some(ref bias_path_str) = cs.bias_path { let bias_file = resolve_relative_path(slices_dir, bias_path_str)?; if !bias_file.exists() { return Err(DsperseError::Pipeline(format!( "configured bias file not found: {} (bias_path={bias_path_str})", bias_file.display() ))); } let bias_data = read_msgpack(&bias_file)?; let bias_flat = crate::utils::io::flatten_nested_list(&bias_data); if bias_flat.len() != cs.c_out { return Err(DsperseError::Pipeline(format!( "bias length {} does not match c_out {}", bias_flat.len(), cs.c_out ))); } if let Some(ref mut acc) = accumulated { for ((_, c, _, _), val) in acc.indexed_iter_mut() { *val += bias_flat[c]; } } } let output = match accumulated { Some(acc) => reshape_channel_split_output(acc.into_dyn(), target_shape)?, None => { return Err(DsperseError::Pipeline(format!( "channel_split produced no output for '{}'", cs.output_name ))); } }; Ok(crate::schema::execution::StrategyOutput { info: ExecutionInfo { method: ExecutionMethod::ChannelSplit, success: true, error: None, witness_file: None, tile_exec_infos: Vec::new(), }, outputs: vec![(cs.output_name.clone(), output)], }) } fn execute_channel_group( slices_dir: &Path, group_dir: &Path, group: &ChannelGroupInfo, group_input: &ArrayD, backend: &JstproveBackend, donor_init_map: Option<&HashMap>, ) -> Result> { let onnx_path = resolve_relative_path(slices_dir, &group.path)?; let patched_onnx = if let Some(map) = donor_init_map { Some(crate::slicer::onnx_proto::build_patched_onnx( &onnx_path, map, )?) 
} else { None }; let effective_onnx = patched_onnx .as_ref() .map_or(onnx_path.as_path(), |t| t.path()); if let Some(circuit_path) = resolve_circuit_path_optional(slices_dir, group.jstprove_circuit_path.as_deref())? { let params = backend.load_params(&circuit_path)?; let is_wai = params.as_ref().is_some_and(|p| p.weights_as_inputs); if donor_init_map.is_some() && !is_wai { return Err(DsperseError::Pipeline(format!( "group_{}: consumer weights require circuits compiled with --weights-as-inputs", group.group_idx ))); } let output_tensor = run_onnx_inference(effective_onnx, group_input)?; let flat: Vec = group_input.iter().copied().collect(); let witness_bytes = if is_wai { generate_wai_witness( backend, &circuit_path, &onnx_path, donor_init_map, params.as_ref().unwrap(), &flat, )? } else { backend.witness_f64(&circuit_path, &flat, &[])? }; let witness_path = group_dir.join(crate::utils::paths::WITNESS_FILE); std::fs::write(&witness_path, &witness_bytes) .map_err(|e| DsperseError::io(e, &witness_path))?; Ok(output_tensor) } else { run_onnx_inference(effective_onnx, group_input) } } ================================================ FILE: crates/dsperse/src/pipeline/combined.rs ================================================ use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; use ndarray::{ArrayD, IxDyn}; use super::incremental::SliceWork; use super::runner::{build_execution_chain, build_run_metadata, load_model_metadata}; use super::strategy::ExecutionStrategy; use super::tensor_store::TensorStore; use crate::backend::onnx::NamedOutputs; use crate::error::{DsperseError, Result}; use crate::schema::execution::{ExecutionChain, RunMetadata}; use crate::schema::metadata::ModelMetadata; pub struct CombinedRun { tensor_cache: TensorStore, model_meta: ModelMetadata, run_meta: RunMetadata, execution_chain: ExecutionChain, slices_dir: PathBuf, pending_slices: HashSet, failed_slices: HashSet, } impl CombinedRun { pub fn new(slices_dir: &Path, input: ArrayD) 
-> Result {
        let model_meta = load_model_metadata(slices_dir)?;
        let combined_path =
            crate::slicer::combiner::ensure_combined_materialized(slices_dir, &model_meta)?;
        crate::slicer::materializer::ensure_all_slices_materialized(slices_dir, &model_meta)?;
        let first_slice = model_meta
            .slices
            .first()
            .ok_or_else(|| DsperseError::Pipeline("model has no slices".into()))?;
        let declared_inputs = &first_slice.dependencies.filtered_inputs;
        if declared_inputs.is_empty() {
            return Err(DsperseError::Pipeline(
                "first slice has no input dependency".into(),
            ));
        }
        let named_outputs = run_combined_onnx(&combined_path, &input, declared_inputs)?;
        // Cache every named tensor the combined inference produced.
        let mut tensor_cache = TensorStore::new();
        for (name, (data, shape)) in &named_outputs {
            let arr = ArrayD::from_shape_vec(IxDyn(shape), data.clone())
                .map_err(|e| DsperseError::Pipeline(format!("output reshape '{name}': {e}")))?;
            tensor_cache.put(name.clone(), arr);
        }
        // Expose the raw model input under any declared input name the
        // combined run did not already produce.
        for name in declared_inputs {
            if !tensor_cache.contains(name) {
                tensor_cache.put(name.clone(), input.clone());
            }
        }
        // Seed the tensor_cache with any initializer-backed tensor
        // the slice metadata references. The slicer's constant-
        // folding passes can turn intermediate tensors (e.g. a
        // Transpose over a constant) into initializers in the
        // transformed graph, while leaving downstream slice
        // metadata pointing at the original tensor name. ORT
        // does not emit those names among its named outputs (they
        // are not declared as graph outputs of combined.onnx and
        // have no producing node), so without this seed the
        // subsequent `tensor_cache.get` in `all_circuit_work` fails
        // with `tensor '<name>' not found in store` and the whole
        // run aborts before a single DSlice gets dispatched.
        seed_tensor_cache_from_initializers(&combined_path, &model_meta, &mut tensor_cache)?;
        let chain = build_execution_chain(&model_meta, slices_dir)?;
        let run_meta = build_run_metadata(&model_meta, slices_dir, &chain)?;
        // Every slice whose chain node uses a circuit starts out pending.
        let mut pending_slices = HashSet::new();
        for slice in &model_meta.slices {
            let slice_id = format!("slice_{}", slice.index);
            let node = chain.nodes.get(&slice_id).ok_or_else(|| {
                DsperseError::Pipeline(format!("execution chain missing node for {slice_id}"))
            })?;
            if node.use_circuit {
                pending_slices.insert(slice_id);
            }
        }
        tracing::info!(
            total_slices = model_meta.slices.len(),
            circuit_slices = pending_slices.len(),
            cached_tensors = tensor_cache.len(),
            "combined inference complete, all circuit work queued"
        );
        Ok(Self {
            tensor_cache,
            model_meta,
            run_meta,
            execution_chain: chain,
            slices_dir: slices_dir.to_path_buf(),
            pending_slices,
            failed_slices: HashSet::new(),
        })
    }

    /// Build a `SliceWork` item for every still-pending slice, in model
    /// slice order, with inputs taken from the tensor cache.
    ///
    /// Channel-split, dim-split, and tiled strategies receive the single
    /// tensor their metadata names. A `Single` slice receives each filtered
    /// input by name, plus all of them concatenated into one flat 1-D
    /// array.
    ///
    /// Errors if a pending slice is missing from the execution chain or
    /// the run metadata, if its strategy cannot be derived, or if a
    /// required tensor is not cached.
    pub fn all_circuit_work(&self) -> Result> {
        let mut work_items = Vec::with_capacity(self.pending_slices.len());
        for slice in &self.model_meta.slices {
            let slice_id = format!("slice_{}", slice.index);
            if !self.pending_slices.contains(&slice_id) {
                continue;
            }
            let node = self.execution_chain.nodes.get(&slice_id).ok_or_else(|| {
                DsperseError::Pipeline(format!("execution chain missing node for {slice_id}"))
            })?;
            let meta = self.run_meta.slices.get(&slice_id).ok_or_else(|| {
                DsperseError::Pipeline(format!("run metadata missing slice {slice_id}"))
            })?;
            let strategy = ExecutionStrategy::from_metadata(meta, node.use_circuit)?;
            let (input, named_inputs) = match strategy {
                ExecutionStrategy::ChannelSplit(cs) => {
                    let t = self.tensor_cache.get(&cs.input_name)?.clone();
                    (t, Vec::new())
                }
                ExecutionStrategy::DimSplit(ds) => {
                    let t = self.tensor_cache.get(&ds.input_name)?.clone();
                    (t, Vec::new())
                }
                ExecutionStrategy::Tiled(tiling) => {
                    let t = self.tensor_cache.get(&tiling.input_name)?.clone();
                    (t, Vec::new())
                }
                ExecutionStrategy::Single { .. } => {
                    // Concatenate all filtered inputs (in declaration order)
                    // into one flat vector, keeping the named copies too.
                    let filtered = &meta.dependencies.filtered_inputs;
                    let mut named = Vec::with_capacity(filtered.len());
                    let mut flat_elems: Vec = Vec::new();
                    for name in filtered {
                        let arr = self.tensor_cache.get(name)?;
                        named.push((name.clone(), arr.clone()));
                        flat_elems.extend(arr.iter());
                    }
                    let concatenated = ndarray::ArrayD::from_shape_vec(
                        ndarray::IxDyn(&[flat_elems.len()]),
                        flat_elems,
                    )
                    .map_err(|e| DsperseError::Pipeline(format!("flatten inputs: {e}")))?;
                    (concatenated, named)
                }
            };
            work_items.push(SliceWork {
                slice_id,
                input,
                named_inputs,
                backend: node.backend,
                use_circuit: node.use_circuit,
                tiling: meta.tiling.clone(),
                channel_split: meta.channel_split.clone(),
                circuit_path: node.circuit_path.clone(),
                onnx_path: node.onnx_path.clone(),
                slice_meta: meta.clone(),
            });
        }
        Ok(work_items)
    }

    /// Remove `slice_id` from the pending set; returns whether it was
    /// pending.
    pub fn mark_slice_done(&mut self, slice_id: &str) -> bool {
        self.pending_slices.remove(slice_id)
    }

    /// Remove `slice_id` from the pending set and, only if it was pending,
    /// record it as failed. Returns whether it was pending.
    pub fn mark_slice_failed(&mut self, slice_id: &str) -> bool {
        let was_pending = self.pending_slices.remove(slice_id);
        if was_pending {
            self.failed_slices.insert(slice_id.to_string());
        }
        was_pending
    }

    /// Whether `slice_id` has been recorded as failed.
    pub fn is_slice_failed(&self, slice_id: &str) -> bool {
        self.failed_slices.contains(slice_id)
    }

    /// Number of slices recorded as failed.
    pub fn failed_count(&self) -> usize {
        self.failed_slices.len()
    }

    /// True once no slice is pending (each was marked done or failed).
    pub fn is_complete(&self) -> bool {
        self.pending_slices.is_empty()
    }

    /// Model metadata loaded in `new`.
    pub fn model_meta(&self) -> &ModelMetadata {
        &self.model_meta
    }

    /// Cached output tensor of the last slice, if available.
    ///
    /// Uses the strategy's output name when it has one, otherwise the
    /// slice's first declared output. Returns `None` if any lookup (slice,
    /// metadata, strategy, or tensor) fails.
    pub fn final_output(&self) -> Option<&ArrayD> {
        let last_slice = self.model_meta.slices.last()?;
        let slice_id = format!("slice_{}", last_slice.index);
        let meta = self.run_meta.slices.get(&slice_id)?;
        let strategy = ExecutionStrategy::from_metadata(meta, false).ok()?;
        match strategy.output_name() {
            Some(name) => self.tensor_cache.try_get(name),
            None => {
                let output_name = meta.dependencies.output.first()?;
                self.tensor_cache.try_get(output_name)
            }
        }
    }

    /// Flattened concatenation of a slice's declared outputs, as computed
    /// by the combined run (see `outputs_for_names`).
    pub fn expected_slice_outputs(&self, slice_id: &str) -> Option> {
        let meta = self.run_meta.slices.get(slice_id)?;
        let output_names = &meta.dependencies.output;
        self.outputs_for_names(output_names)
    }

    /// Concatenate the named cached tensors, in the given order and each
    /// in logical iteration order, into one flat vector. Returns `None`
    /// if any name is not cached or the result is empty.
    pub fn outputs_for_names(&self, names: &[String]) -> Option> {
        let mut flat = Vec::new();
        for name in names {
            let tensor = self.tensor_cache.try_get(name)?;
            flat.extend(tensor.iter());
        }
        if flat.is_empty() {
            None
        } else {
            Some(flat)
        }
    }

    /// Returns `(slice count, total tile count, per-slice tile count keyed
    /// by "slice_<index>")`. Untiled slices count as one tile.
    pub fn slice_tile_counts(&self) -> (usize, usize, HashMap) {
        let total_slices = self.model_meta.slices.len();
        let mut map = HashMap::with_capacity(total_slices);
        let mut total_tiles = 0usize;
        for s in &self.model_meta.slices {
            let tiles = s.tiling.as_ref().map(|t| t.num_tiles).unwrap_or(1);
            map.insert(format!("slice_{}", s.index), tiles);
            total_tiles += tiles;
        }
        (total_slices, total_tiles, map)
    }

    /// Slices directory this run was created from.
    pub fn slices_dir(&self) -> &Path {
        &self.slices_dir
    }

    /// Number of slices still awaiting circuit work.
    pub fn pending_count(&self) -> usize {
        self.pending_slices.len()
    }
}

/// Run the combined ONNX graph on `input` and return its named outputs.
///
/// Only single-input graphs are supported: `declared_inputs` (the first
/// slice's filtered inputs) must have exactly one entry, otherwise a
/// `DsperseError::Pipeline` is returned.
fn run_combined_onnx(
    combined_path: &Path,
    input: &ArrayD,
    declared_inputs: &[String],
) -> Result {
    if declared_inputs.len() == 1 {
        let input_flat: Vec = input.iter().copied().collect();
        let input_shape = input.shape();
        crate::backend::onnx::run_inference_named(combined_path, &input_flat, input_shape)
    } else {
        Err(DsperseError::Pipeline(format!(
            "combined mode requires single input, got {}",
            declared_inputs.len()
        )))
    }
}

/// Populate `tensor_cache` with any combined-graph initializer
/// whose name appears in slice metadata as a `filtered_input` or a
/// declared `output`. Without this, a slice that depends on a
/// constant-folded tensor (one the slicer turned from a node
/// output into an initializer) would fail at the
/// `tensor_cache.get(name)` call in `all_circuit_work` even though
/// the value is right there in the combined ONNX.
///
/// Initializers already present in `tensor_cache` are never
/// overwritten. An initializer is skipped (debug log only, not an
/// error) when one of its dims does not convert to `usize` (e.g. it
/// is negative), when its shape product overflows `usize`, or when
/// its decoded element count disagrees with its declared shape. A
/// model with no graph seeds nothing.
fn seed_tensor_cache_from_initializers(
    combined_path: &Path,
    model_meta: &ModelMetadata,
    tensor_cache: &mut TensorStore,
) -> Result<()> {
    // Every tensor name a slice consumes (`filtered_inputs`) or
    // produces (`output`); only initializers named here are seeded.
    let needed: HashSet<&str> = model_meta
        .slices
        .iter()
        .flat_map(|s| {
            s.dependencies
                .filtered_inputs
                .iter()
                .chain(s.dependencies.output.iter())
        })
        .map(String::as_str)
        .collect();
    // Avoid parsing the combined ONNX at all when no slice declares
    // any dependency names.
    if needed.is_empty() {
        return Ok(());
    }
    let model = crate::slicer::onnx_proto::load_model(combined_path)?;
    let graph = match &model.graph {
        Some(g) => g,
        None => return Ok(()),
    };
    let mut seeded = 0usize;
    for init in &graph.initializer {
        if !needed.contains(init.name.as_str()) {
            continue;
        }
        // Never clobber a tensor that is already cached.
        if tensor_cache.contains(&init.name) {
            continue;
        }
        // Negative dims would silently wrap to huge positive
        // values via `as usize`; `usize::try_from` rejects them up
        // front so a malformed initialiser is skipped here (with a
        // debug log) instead of allocating a multi-petabyte array
        // below.
        let shape: Vec = match init
            .dims
            .iter()
            .map(|&d| usize::try_from(d))
            .collect::, _>>()
        {
            Ok(s) => s,
            Err(e) => {
                tracing::debug!(
                    name = %init.name,
                    dims = ?init.dims,
                    error = %e,
                    "skipping initializer-backed slice tensor: invalid (negative) dimension"
                );
                continue;
            }
        };
        // Use checked_mul so an arithmetic overflow surfaces as a
        // skip (and the slice executor downstream produces a
        // clearer error if it actually needed the value), instead
        // of wrapping silently and mis-comparing against
        // `data.len()`.
        let expected: Option = shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d));
        let Some(expected) = expected else {
            tracing::debug!(
                name = %init.name,
                dims = ?init.dims,
                "skipping initializer-backed slice tensor: shape product overflowed usize"
            );
            continue;
        };
        // Decode straight to f64 so DOUBLE / INT64 initialisers
        // keep their full precision -- the previous f32-then-widen
        // chain truncated DOUBLE mantissas and silently lost
        // precision on INT64 magnitudes outside f32's exact range.
        let data: Vec = crate::slicer::onnx_proto::tensor_to_f64(init);
        if data.len() != expected {
            // Skip rather than fail: an initialiser whose declared
            // shape doesn't match its element count can still be
            // useful elsewhere (some quantised tensors store packed
            // bytes), but we cannot reshape it into an f64 `ArrayD`
            // here without guessing. Leave it to the slice ONNX
            // executor to surface a clearer error if it actually
            // needs the value.
            tracing::debug!(
                name = %init.name,
                declared_shape = ?shape,
                declared_elements = expected,
                actual_elements = data.len(),
                "skipping initializer-backed slice tensor: declared shape != element count"
            );
            continue;
        }
        // The length check above means this reshape should not fail;
        // the error mapping is a defensive fallback.
        let arr = ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|e| {
            DsperseError::Pipeline(format!(
                "seed initializer-backed tensor '{}' from combined.onnx: {e}",
                init.name
            ))
        })?;
        tensor_cache.put(init.name.clone(), arr);
        seeded += 1;
    }
    if seeded > 0 {
        tracing::info!(
            seeded,
            "seeded tensor_cache with constant-folded slice-input initializers"
        );
    }
    Ok(())
}
================================================ FILE: crates/dsperse/src/pipeline/compiler.rs ================================================
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use rayon::prelude::*;

use crate::backend::jstprove::JstproveBackend;
use crate::converter;
use crate::error::{DsperseError, Result};
use crate::schema::metadata::ModelMetadata;
use crate::slicer::autotiler::estimate_slice_constraints;
use crate::slicer::onnx_proto;
use crate::utils::paths::{find_metadata_path, slice_dir_path};

/// Circuit cache shared by the rayon workers in `compile_slices`.
/// It is keyed by the string returned from
/// `compute_circuit_signature` and maps to the directory of an
/// already-compiled bundle. The split compile paths can then copy
/// an identical circuit instead of recompiling it.
type CircuitCache = std::sync::Mutex>;

/// Per-slice result of `compile_single_slice`. `compile_slices`
/// folds it into the compiled counter and the metadata updates.
enum CompileOutcome {
    /// A standard slice circuit was compiled, or a cached bundle
    /// that `load_params` accepted was reused.
    Compiled,
    /// A channel-split slice was compiled. Each entry pairs a group
    /// index with the circuit path to record on that group's
    /// metadata.
    CompiledChannelSplit {
        group_circuits: Vec<(usize, String)>,
    },
    /// A dim-split template circuit was compiled.
    CompiledDimSplit,
    /// The slice was deliberately left uncompiled, e.g. because of
    /// unsupported ops or a data-movement-only graph.
    Skipped,
    /// The constraint estimate exceeded `skip_compile_over_size`.
    SkippedOverSize {
        estimated: u64,
        threshold: u64,
    },
}

/// Summary of a compile_slices invocation. The pass returns Ok
/// even when individual slice compilations fail, so callers must
/// inspect `failed` to decide whether to proceed (e.g. allow
/// partial-coverage ONNX fallback) or abort. Keeping the
/// compiled count explicit lets the CLI / analyze command
/// report a structured summary instead of inferring success from
/// log lines.
#[derive(Debug, Default)]
pub struct CompileReport {
    /// Number of slices that produced a compiled outcome (standard,
    /// channel-split or dim-split). Skipped slices are not counted.
    pub compiled: usize,
    /// `(slice index, error)` for every slice whose compile step
    /// returned an error.
    pub failed: Vec<(usize, DsperseError)>,
}

impl CompileReport {
    /// Return Ok(self) when every slice compiled cleanly. Otherwise
    /// return a generic Pipeline error; callers layer their own
    /// actionable guidance on top (the CLI mentions its
    /// --allow-onnx-fallback flag, the Python binding mentions the
    /// `allow_onnx_fallback` keyword). Keeping the library message
    /// surface-agnostic avoids leaking CLI conventions into the
    /// Python / Rust API error stream.
    pub fn ok_if_no_failures(self) -> Result {
        if self.failed.is_empty() {
            Ok(self)
        } else {
            Err(DsperseError::Pipeline(format!(
                "compile_slices: {} slice(s) failed to compile; the caller must opt in to partial coverage before proceeding",
                self.failed.len()
            )))
        }
    }
}

/// Backfill split metadata fields that only become resolvable after
/// slicing (channel_split.groups populated from disk,
/// dim_split.template_path inferred from the materialized template
/// ONNX), and strip dim_split entries whose template could not be
/// materialized. Called from both compile_slices and analyze_slices
/// so the two classifications agree on what actually counts as a
/// channel- or dim-split slice. Persists the normalised metadata
/// back to disk when any field changes.
fn normalize_split_metadata(
    slices_dir: &Path,
    meta_path: &Path,
    metadata: &mut ModelMetadata,
) -> Result<()> {
    // When the metadata still records the original model path,
    // presumably (re)materialise any slice artefacts missing on disk
    // so the existence probes below can see them. Confirm in
    // slicer::materializer.
    if metadata.original_model_path.is_some() {
        crate::slicer::materializer::ensure_all_slices_materialized(slices_dir, metadata)?;
    }
    let mut metadata_dirty = false;
    for slice in &mut metadata.slices {
        // Rebuild an empty channel-split group list from the
        // payload/channel_groups files when they are all present.
        if let Some(ref mut cs) = slice.channel_split
            && cs.groups.is_empty()
        {
            let populated = populate_channel_split_groups(slices_dir, slice.index, cs)?;
            if populated {
                metadata_dirty = true;
            }
        }
        // Infer a missing dim-split template path from the
        // conventional payload location when that file exists.
        if let Some(ref mut ds) = slice.dim_split
            && ds.template_path.is_none()
        {
            let tmpl_rel = format!("slice_{}/payload/dim_template.onnx", slice.index);
            if slices_dir.join(&tmpl_rel).exists() {
                ds.template_path = Some(tmpl_rel);
                metadata_dirty = true;
            }
        }
    }
    // Strip dim_split metadata from slices where template creation
    // failed (axis-separability rejection, unsupported split kind).
    // Leaving stale dim_split entries in the metadata causes
    // downstream runners and the packager to emit bundles that fail
    // at the strategy validation stage ("dim_split present but
    // template_path is missing").
    for slice in &mut metadata.slices {
        if slice
            .dim_split
            .as_ref()
            .is_some_and(|ds| ds.template_path.is_none())
        {
            tracing::info!(
                slice = slice.index,
                "stripping dim_split metadata (no template materialized)"
            );
            slice.dim_split = None;
            metadata_dirty = true;
        }
    }
    if metadata_dirty {
        metadata.save(meta_path)?;
        tracing::info!("persisted materialized split groups to metadata");
    }
    Ok(())
}

/// Compile every slice (or only the indices listed in `layers`)
/// into a JSTprove circuit bundle. Slices are processed
/// concurrently on a dedicated rayon pool built with `parallel`
/// threads.
///
/// Split metadata is normalised first (see
/// `normalize_split_metadata`). Channel-split group circuit paths
/// and dim-split template circuit paths produced during the pass
/// are written back to the metadata file before returning.
///
/// Per-slice failures do not abort the pass; they are collected in
/// [`CompileReport::failed`]. `Err` is returned only for setup
/// problems:
/// - `holographic` requested without the `GoldilocksExt4Whir`
///   proof config,
/// - a missing or unreadable metadata file,
/// - a failed normalisation,
/// - a thread-pool build failure,
/// - a failed metadata save.
#[allow(clippy::too_many_arguments)]
pub fn compile_slices(
    slices_dir: &Path,
    backend: &JstproveBackend,
    proof_config: jstprove_circuits::api::ProofConfigType,
    parallel: usize,
    weights_as_inputs: bool,
    layers: Option<&[usize]>,
    jstprove_ops: &[&str],
    skip_compile_over_size: Option,
    holographic: bool,
) -> Result {
    if holographic && proof_config != jstprove_circuits::api::ProofConfigType::GoldilocksExt4Whir {
        return Err(DsperseError::Pipeline(format!(
            "--holographic requires --proof-config goldilocks_ext4_whir; got {proof_config}"
        )));
    }
    let meta_path = find_metadata_path(slices_dir).ok_or_else(|| {
        DsperseError::Metadata(format!(
            "no {} found in slices directory",
            crate::utils::paths::METADATA_FILE
        ))
    })?;
    let mut metadata = ModelMetadata::load(&meta_path)?;
    normalize_split_metadata(slices_dir, &meta_path, &mut metadata)?;
    // Clone the selected slices so the workers do not borrow
    // `metadata`, which `meta_mutex` holds mutably below.
    let slices: Vec<_> = metadata
        .slices
        .iter()
        .filter(|s| layers.is_none_or(|l| l.contains(&s.index)))
        .cloned()
        .collect();
    tracing::info!(total = slices.len(), "compiling slices");
    // Names of constant-folded tensors, forwarded to the converter
    // as its weights-as-inputs exclusion set.
    let exclude_from_wai: std::collections::HashSet =
        metadata.folded_constant_names.iter().cloned().collect();
    let traced_shapes = metadata.traced_shapes.clone();
    let traced_ref = traced_shapes.as_ref();
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(parallel)
        .build()
        .map_err(|e| DsperseError::Pipeline(format!("thread pool: {e}")))?;
    let compiled_count = std::sync::atomic::AtomicUsize::new(0);
    // Workers record split circuit paths through this mutex. The
    // bool tracks whether anything changed and the metadata must be
    // re-saved after the pass.
    let meta_mutex = std::sync::Mutex::new((&mut metadata, false));
    let errors: std::sync::Mutex> = std::sync::Mutex::new(Vec::new());
    let circuit_cache: CircuitCache =
        std::sync::Mutex::new(HashMap::new());
    pool.install(|| {
        slices.par_iter().for_each(|slice| {
            let r = compile_single_slice(
                slices_dir,
                slice,
                backend,
                proof_config,
                weights_as_inputs,
                jstprove_ops,
                &exclude_from_wai,
                skip_compile_over_size,
                &circuit_cache,
                traced_ref,
                holographic,
            );
            match r {
                Ok(CompileOutcome::Compiled) => {
                    let count =
                        compiled_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
                    tracing::info!(slice = slice.index, count, "compiled");
                }
                Ok(CompileOutcome::CompiledChannelSplit { group_circuits }) => {
                    let count =
                        compiled_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
                    tracing::info!(
                        slice = slice.index,
                        groups = group_circuits.len(),
                        count,
                        "compiled channel split groups"
                    );
                    // Record each group's circuit path on the shared
                    // metadata copy.
                    let mut guard = meta_mutex.lock().unwrap();
                    let (ref mut meta, ref mut dirty) = *guard;
                    if let Some(s) = meta.slices.iter_mut().find(|s| s.index == slice.index)
                        && let Some(ref mut cs) = s.channel_split
                    {
                        for (group_idx, circuit_path) in &group_circuits {
                            if let Some(group) =
                                cs.groups.iter_mut().find(|g| g.group_idx == *group_idx)
                            {
                                group.jstprove_circuit_path = Some(circuit_path.clone());
                            }
                        }
                        *dirty = true;
                    }
                }
                Ok(CompileOutcome::CompiledDimSplit) => {
                    let count =
                        compiled_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
                    tracing::info!(slice = slice.index, count, "compiled dim-split template");
                    let mut guard = meta_mutex.lock().unwrap();
                    let (ref mut meta, ref mut dirty) = *guard;
                    if let Some(s) = meta.slices.iter_mut().find(|s| s.index == slice.index)
                        && let Some(ref mut ds) = s.dim_split
                    {
                        ds.jstprove_circuit_path = Some(format!(
                            "slice_{}/jstprove/dim_split/circuit.bundle",
                            slice.index
                        ));
                        *dirty = true;
                    }
                }
                Ok(CompileOutcome::Skipped) => {
                    // NOTE: compile_single_slice also returns Skipped
                    // for data-movement-only slices (after logging
                    // that reason itself), so this message is not
                    // always the real reason.
                    tracing::info!(slice = slice.index, "skipped (unsupported ops)")
                }
                Ok(CompileOutcome::SkippedOverSize {
                    estimated,
                    threshold,
                }) => {
                    tracing::info!(
                        slice = slice.index,
                        estimated,
                        threshold,
                        "skipped (estimated constraints exceed threshold)"
                    )
                }
                Err(e) => {
                    // Per-slice compile failure is recoverable: the
                    // summary log at the end of compile_slices
                    // already surfaces the aggregate via warn!, and
                    // the caller decides whether to continue with
                    // partial coverage. Emitting error! here would
                    // spam CI for an outcome that ok_if_no_failures
                    // handles structurally.
                    tracing::warn!(slice = slice.index, error = %e, "compilation failed");
                    errors.lock().unwrap().push((slice.index, e));
                }
            }
        });
    });
    let errors = errors.into_inner().unwrap();
    let (metadata, cs_dirty) = meta_mutex.into_inner().unwrap();
    if cs_dirty {
        // Swallowing the save failure would let downstream
        // analyze / run / package observe an in-memory set of
        // materialised channel / dim-split circuit paths that the
        // on-disk metadata doesn't know about -- the very problem
        // normalize_split_metadata exists to prevent. Propagate.
        metadata.save(&meta_path)?;
        tracing::info!("persisted split circuit paths to metadata");
    }
    let compiled_count = compiled_count.load(std::sync::atomic::Ordering::Relaxed);
    if errors.is_empty() {
        tracing::info!(count = compiled_count, "all slices compiled");
    } else {
        tracing::warn!(
            compiled = compiled_count,
            failed = errors.len(),
            "compilation completed with errors; failed slices fall back to ONNX execution if the caller allows partial coverage"
        );
        for (idx, e) in &errors {
            tracing::warn!(slice = idx, error = %e, "slice compilation failed");
        }
    }
    Ok(CompileReport {
        compiled: compiled_count,
        failed: errors,
    })
}

/// Result of `analyze_slice_onnx`.
struct SliceAnalysis {
    /// Every node's op type is in the caller's `jstprove_ops` list.
    compatible: bool,
    /// The graph is non-empty and every node is in `DATA_MOVEMENT_OPS`.
    data_movement_only: bool,
}

/// Op types that only reshape, reorder, select or cast data. A slice
/// made solely of these is skipped by `compile_single_slice` and
/// reported as "data movement only" by `analyze_slices`.
const DATA_MOVEMENT_OPS: &[&str] = &[
    "Reshape",
    "Transpose",
    "Flatten",
    "Squeeze",
    "Unsqueeze",
    "Identity",
    "Concat",
    "Split",
    "Gather",
    "Slice",
    "Expand",
    "Tile",
    "Cast",
];

/// Load the slice ONNX at `onnx_path` and classify its op set
/// against `jstprove_ops` and `DATA_MOVEMENT_OPS`.
///
/// An empty graph is reported as `compatible` (vacuous `all`) but
/// never as `data_movement_only`. Errors when the model cannot be
/// loaded or has no graph.
fn analyze_slice_onnx(onnx_path: &Path, jstprove_ops: &[&str]) -> Result {
    let model = onnx_proto::load_model(onnx_path)?;
    let graph = model
        .graph
        .as_ref()
        .ok_or_else(|| DsperseError::Slicer(format!("no graph in {}", onnx_path.display())))?;
    let compatible = graph
        .node
        .iter()
        .all(|n| jstprove_ops.contains(&n.op_type.as_str()));
    let data_movement_only = !graph.node.is_empty()
        && graph
            .node
            .iter()
            .all(|n| DATA_MOVEMENT_OPS.contains(&n.op_type.as_str()));
    Ok(SliceAnalysis {
        compatible,
        data_movement_only,
    })
}

/// Hex-encoded SHA-256 over the structural content of the ONNX graph
/// at `tmpl_path`. The optional `curve` string is hashed first.
///
/// What is hashed:
/// - every node's op type, its input/output names, and the `type`,
///   `i`, `f`, `s`, `ints`, `floats` and `strings` fields of each
///   attribute;
/// - the shape and element type of every non-initializer graph
///   input and of every graph output (their names are not hashed);
/// - the dims and data type, but not the values, of every
///   initializer.
///
/// Variable-length fields are length-prefixed so adjacent fields
/// cannot alias. Because initializer contents are excluded, graphs
/// that differ only in weight values share a signature.
///
/// NOTE(review): other attribute payloads (e.g. a tensor-valued
/// `Constant` attribute) are not mixed in, so graphs differing only
/// there would also collide. Confirm the slicer folds such values
/// into initializers before relying on this for circuit-cache reuse.
/// Any change to the hashed layout shifts every signature, and with
/// it every `compute_bundle_signature`.
pub(super) fn compute_circuit_signature(tmpl_path: &Path, curve: Option<&str>) -> Result {
    use sha2::{Digest, Sha256};

    // Length-prefix variable-length byte strings so concatenated
    // fields cannot produce the same byte stream.
    fn hash_bytes(hasher: &mut Sha256, b: &[u8]) {
        hasher.update((b.len() as u64).to_le_bytes());
        hasher.update(b);
    }

    let model = onnx_proto::load_model(tmpl_path)?;
    let graph = model
        .graph
        .as_ref()
        .ok_or_else(|| DsperseError::Slicer("no graph for signature".into()))?;
    let mut hasher = Sha256::new();
    if let Some(c) = curve {
        hash_bytes(&mut hasher, c.as_bytes());
    }
    hasher.update((graph.node.len() as u64).to_le_bytes());
    for node in &graph.node {
        hash_bytes(&mut hasher, node.op_type.as_bytes());
        hasher.update((node.input.len() as u64).to_le_bytes());
        for inp in &node.input {
            hash_bytes(&mut hasher, inp.as_bytes());
        }
        hasher.update((node.output.len() as u64).to_le_bytes());
        for out in &node.output {
            hash_bytes(&mut hasher, out.as_bytes());
        }
        hasher.update((node.attribute.len() as u64).to_le_bytes());
        for attr in &node.attribute {
            hash_bytes(&mut hasher, attr.name.as_bytes());
            hasher.update(attr.r#type.to_le_bytes());
            hasher.update(attr.i.to_le_bytes());
            hasher.update(attr.f.to_le_bytes());
            hash_bytes(&mut hasher, &attr.s);
            hasher.update((attr.ints.len() as u64).to_le_bytes());
            for v in &attr.ints {
                hasher.update(v.to_le_bytes());
            }
            hasher.update((attr.floats.len() as u64).to_le_bytes());
            for v in &attr.floats {
                hasher.update(v.to_le_bytes());
            }
            hasher.update((attr.strings.len() as u64).to_le_bytes());
            for v in &attr.strings {
                hash_bytes(&mut hasher, v);
            }
        }
    }
    // Graph inputs that are backed by initializers are covered by the
    // initializer loop below; only hash the real runtime inputs here.
    let init_names: std::collections::HashSet<&str> =
        graph.initializer.iter().map(|i| i.name.as_str()).collect();
    for vi in &graph.input {
        if init_names.contains(vi.name.as_str()) {
            continue;
        }
        if let Some(shape) = onnx_proto::shape_from_value_info(vi) {
            hasher.update((shape.len() as u64).to_le_bytes());
            for d in &shape {
                hasher.update(d.to_le_bytes());
            }
        }
        if let Some(dt) = onnx_proto::elem_type_from_value_info(vi) {
            hasher.update(dt.to_le_bytes());
        }
    }
    for vi in &graph.output {
        if let Some(shape) = onnx_proto::shape_from_value_info(vi) {
            hasher.update((shape.len() as u64).to_le_bytes());
            for d in &shape {
                hasher.update(d.to_le_bytes());
            }
        }
        if let Some(dt) = onnx_proto::elem_type_from_value_info(vi) {
            hasher.update(dt.to_le_bytes());
        }
    }
    hasher.update((graph.initializer.len() as u64).to_le_bytes());
    for init in &graph.initializer {
        hasher.update((init.dims.len() as u64).to_le_bytes());
        for d in &init.dims {
            hasher.update(d.to_le_bytes());
        }
        hasher.update(init.data_type.to_le_bytes());
    }
    let hash = hasher.finalize();
    Ok(format!("{:x}", hash))
}

/// Bundle-aware signature used at packaging time. Wraps the ONNX+curve
/// hash from `compute_circuit_signature` with discriminators pulled
/// from the compiled bundle so that two packages built from the same
/// ONNX but under different proof configs, input-binding modes, or
/// holographic/non-holographic flows land at distinct shas in the
/// content-addressed registry.
///
/// The compile-time cache lookups reached from `compile_single_slice`
/// (e.g. the circuit-cache probe in `compile_channel_split_slice`)
/// continue to use `compute_circuit_signature` directly. They key on
/// pre-compile state, where no bundle exists yet.
pub(super) fn compute_bundle_signature(
    tmpl_path: &Path,
    curve: Option<&str>,
    bundle_dir: &Path,
) -> Result {
    use sha2::{Digest, Sha256};
    let base = compute_circuit_signature(tmpl_path, curve)?;
    let mut hasher = Sha256::new();
    hasher.update(base.as_bytes());
    // Stability contract: any change to the on-wire / in-hash layout
    // of the bytes mixed in below will silently re-shuffle every
    // content-addressed component sha. The three inputs that must
    // stay byte-stable are
    // * `jstprove_circuits::proof_config::ProofConfig::config_id()`
    //   (CONFIG_ID integers documented in proof_config.rs),
    // * `StampedProofConfig::version` (u32, per
    //   ProofConfig::current_version),
    // * `CircuitParams::weights_as_inputs` serialization (bool).
    // If any of those change their encoding, bump the version tag in
    // the marker below (for example `bundle-disambiguator-v2`) so
    // downstream registries receive a deliberate re-shuffle rather
    // than a silent one.
    hasher.update(b"\x00bundle-disambiguator-v1\x00");
    match jstprove_io::bundle::read_bundle_metadata::(
        bundle_dir,
    ) {
        Ok((Some(params), _)) => {
            // Tag 1: manifest readable and carries circuit params.
            hasher.update([1u8]);
            match params.proof_config {
                Some(stamped) => {
                    hasher.update([1u8]);
                    hasher.update((stamped.config.config_id() as u64).to_le_bytes());
                    hasher.update(stamped.version.to_le_bytes());
                }
                None => hasher.update([0u8]),
            }
            hasher.update([u8::from(params.weights_as_inputs)]);
        }
        Ok((None, _)) => {
            // Tag 0: manifest readable but carries no circuit params.
            hasher.update([0u8]);
        }
        Err(e) => {
            // A malformed or unreadable manifest is meaningfully
            // different from a bundle that legitimately carries no
            // metadata. Distinguish the two with separate
            // discriminator bytes so a corrupt bundle cannot collide
            // with a clean legacy bundle, and surface the failure in
            // the tracing log so operators investigating a shifted
            // sha have the underlying read error to reference.
            tracing::warn!(
                bundle = %bundle_dir.display(),
                error = %e,
                "bundle manifest read failed while computing bundle signature; using error discriminator"
            );
            hasher.update([2u8]);
        }
    }
    // Holographic discriminator: whether the bundle carries a
    // verifying key.
    hasher.update([u8::from(jstprove_io::bundle::bundle_has_vk(bundle_dir))]);
    Ok(format!("{:x}", hasher.finalize()))
}

/// Compact op histogram of an ONNX graph. Op types are listed in
/// sorted order (BTreeMap) and joined by commas; repeated ops get an
/// `x<count>` suffix, e.g. `Add,Convx3,Relux2`. Returns `"?"` when
/// the model cannot be loaded or has no graph.
fn summarize_onnx_ops(onnx_path: &Path) -> String {
    let model = match onnx_proto::load_model(onnx_path) {
        Ok(m) => m,
        Err(_) => return String::from("?"),
    };
    let graph = match model.graph.as_ref() {
        Some(g) => g,
        None => return String::from("?"),
    };
    let mut counts: std::collections::BTreeMap<&str, usize> = std::collections::BTreeMap::new();
    for node in &graph.node {
        *counts.entry(node.op_type.as_str()).or_default() += 1;
    }
    counts
        .iter()
        .map(|(op, n)| {
            if *n > 1 {
                format!("{op}x{n}")
            } else {
                op.to_string()
            }
        })
        .collect::>()
        .join(",")
}

/// One row of `analyze_slices` output, serialisable via serde.
#[derive(Debug, serde::Serialize)]
pub struct SliceAnalysisReport {
    /// Slice index from the metadata.
    pub index: usize,
    /// Backend label, e.g. "jstprove", "onnx", or "missing" when the
    /// slice's artefacts are absent.
    pub backend: String,
    /// Short human-readable justification for `backend`.
    pub reason: String,
    /// Constraint estimate for the representative ONNX, when computable.
    pub estimated_constraints: Option,
    /// Op histogram (see `summarize_onnx_ops`); empty when no ONNX
    /// was available.
    pub ops: String,
    /// Whether the slice carries tiling metadata.
    pub tiled: bool,
    /// Whether the row was classified as a channel-split slice.
    pub channel_split: bool,
    /// Whether the row was classified as a dim-split slice.
    pub dim_split: bool,
    /// Circuit signature of the representative ONNX, when computable.
    pub circuit_signature: Option,
}

/// Derive the three metrics SliceAnalysisReport carries from an
/// ONNX file: op-summary string, constraint estimate, and curve-
/// stamped circuit signature.
///
/// `analyze_slices` uses it for branches that report on a
/// representative ONNX other than the slice's own file: the first
/// channel group's .onnx for channel-split, and the dim-split
/// template ONNX for dim-split.
///
/// Failure on any single metric is non-fatal. A missing file yields
/// `("", None, None)`. An unparsable file yields `"?"` ops and `None`
/// for the metrics that failed. Analyze therefore never aborts on a
/// partially-materialised slice.
fn derive_slice_report_metrics( onnx_path: &Path, proof_config: Option<&str>, ) -> (String, Option, Option) { if !onnx_path.exists() { return (String::new(), None, None); } let ops = summarize_onnx_ops(onnx_path); let estimated = estimate_onnx_constraints(onnx_path).ok(); let signature = compute_circuit_signature(onnx_path, proof_config).ok(); (ops, estimated, signature) } pub fn analyze_slices( slices_dir: &Path, jstprove_ops: &[&str], skip_compile_over_size: Option, proof_config: Option<&str>, ) -> Result> { let meta_path = find_metadata_path(slices_dir).ok_or_else(|| { DsperseError::Metadata(format!( "no {} found in slices directory", crate::utils::paths::METADATA_FILE )) })?; let mut metadata = ModelMetadata::load(&meta_path)?; // Apply the same split-metadata normalisation compile_slices // performs so the backend / reason classifications below see // populated channel_split.groups, inferred dim_split template // paths, and stripped dim_split entries whose template never // materialised. Without this step analyze_slices misreports // slices whose split state is implicit in on-disk artefacts. 
normalize_split_metadata(slices_dir, &meta_path, &mut metadata)?; let mut reports = Vec::with_capacity(metadata.slices.len()); for slice in &metadata.slices { let slice_dir = slice_dir_path(slices_dir, slice.index); if !slice_dir.exists() { reports.push(SliceAnalysisReport { index: slice.index, backend: "missing".into(), reason: "slice directory not found".into(), estimated_constraints: None, ops: String::new(), tiled: slice.tiling.is_some(), channel_split: slice.channel_split.is_some(), dim_split: slice.dim_split.is_some(), circuit_signature: None, }); continue; } if let Some(ref cs) = slice.channel_split && !cs.groups.is_empty() { // Use the first channel-group ONNX as representative // for the reported metrics: every group in the split // shares the same per-chunk topology, so op summary, // constraint estimate, and circuit signature are // group-invariant and the first group is authoritative // for the backend's view of compilation cost. let group_path = slices_dir.join(&cs.groups[0].path); let (ops, estimated, circuit_signature) = derive_slice_report_metrics(&group_path, proof_config); reports.push(SliceAnalysisReport { index: slice.index, backend: "jstprove".into(), reason: "channel-split".into(), estimated_constraints: estimated, ops, tiled: slice.tiling.is_some(), channel_split: true, dim_split: false, circuit_signature, }); continue; } if let Some(ref ds) = slice.dim_split && let Some(ref tmpl_rel) = ds.template_path { // The dim-split template is the ONNX the backend // actually compiles (one circuit shared across every // group), so it is the correct source for the // reported ops / constraint estimate / circuit // signature. 
let tmpl_path = slices_dir.join(tmpl_rel); let (ops, estimated, circuit_signature) = derive_slice_report_metrics(&tmpl_path, proof_config); reports.push(SliceAnalysisReport { index: slice.index, backend: "jstprove".into(), reason: "dim-split".into(), estimated_constraints: estimated, ops, tiled: slice.tiling.is_some(), channel_split: false, dim_split: true, circuit_signature, }); continue; } let onnx_path = match resolve_compile_onnx(slices_dir, slice) { Ok(p) => p, Err(_) => { // resolve_compile_onnx failing means the slice has // no ONNX artefact on disk at all. That is a // genuine "missing" state (the analyse footer // already has a dedicated missing count), not an // "onnx-backend-compatible" slice. reports.push(SliceAnalysisReport { index: slice.index, backend: "missing".into(), reason: "onnx not found".into(), estimated_constraints: None, ops: String::new(), tiled: slice.tiling.is_some(), channel_split: false, dim_split: false, circuit_signature: None, }); continue; } }; if !onnx_path.exists() { // Same reasoning as the resolve_compile_onnx Err branch // above: path was resolvable by metadata but the file // is absent, so the slice is missing rather than ONNX- // compatible. 
reports.push(SliceAnalysisReport { index: slice.index, backend: "missing".into(), reason: "onnx not found".into(), estimated_constraints: None, ops: String::new(), tiled: slice.tiling.is_some(), channel_split: false, dim_split: false, circuit_signature: None, }); continue; } let ops = summarize_onnx_ops(&onnx_path); let analysis = analyze_slice_onnx(&onnx_path, jstprove_ops); let estimated = estimate_onnx_constraints(&onnx_path).ok(); let sig = compute_circuit_signature(&onnx_path, proof_config).ok(); let (backend, reason) = match analysis { Ok(a) if !a.compatible => ("onnx", "unsupported ops"), Ok(a) if a.data_movement_only => ("onnx", "data movement only"), Ok(_) => { if let (Some(est), Some(thresh)) = (estimated, skip_compile_over_size) { if est > thresh { ("onnx", "exceeds size threshold") } else { ("jstprove", "compilable") } } else { ("jstprove", "compilable") } } Err(_) => ("onnx", "analysis failed"), }; reports.push(SliceAnalysisReport { index: slice.index, backend: backend.into(), reason: reason.into(), estimated_constraints: estimated, ops, tiled: slice.tiling.is_some(), channel_split: false, dim_split: slice.dim_split.is_some(), circuit_signature: sig, }); } Ok(reports) } fn estimate_onnx_constraints(onnx_path: &Path) -> Result { let model = onnx_proto::load_model(onnx_path)?; let graph = model .graph .as_ref() .ok_or_else(|| DsperseError::Slicer(format!("no graph in {}", onnx_path.display())))?; let shapes = extract_graph_shapes(graph); Ok(estimate_slice_constraints(&graph.node, &shapes)) } fn extract_graph_shapes( graph: &onnx_proto::GraphProto, ) -> std::collections::HashMap> { let mut shapes = std::collections::HashMap::new(); let extract_vi_shape = |vi: &onnx_proto::ValueInfoProto| -> Option<(String, Vec)> { let tp = vi.r#type.as_ref()?; if let Some(onnx_proto::onnx::type_proto::Value::TensorType(ref tt)) = tp.value { let dims: Vec = tt .shape .as_ref()? 
.dim .iter() .filter_map(|d| { if let Some(onnx_proto::onnx::tensor_shape_proto::dimension::Value::DimValue( v, )) = d.value { Some(v) } else { None } }) .collect(); if !dims.is_empty() { return Some((vi.name.clone(), dims)); } } None }; for vi in graph .input .iter() .chain(graph.output.iter()) .chain(graph.value_info.iter()) { if let Some((name, dims)) = extract_vi_shape(vi) { shapes.insert(name, dims); } } for init in &graph.initializer { if !init.name.is_empty() && !init.dims.is_empty() { shapes.insert(init.name.clone(), init.dims.clone()); } } shapes } fn normalize_slice_for_backend(onnx_path: &Path) -> Result> { let mut model = onnx_proto::load_model(onnx_path)?; let changes = onnx_proto::normalize_for_circuit_backend(&mut model); if changes == 0 { return Ok(None); } let normalized = onnx_path.with_extension("backend.onnx"); onnx_proto::save_model(&model, &normalized)?; Ok(Some(normalized)) } #[allow(clippy::too_many_arguments)] fn compile_single_slice( slices_dir: &Path, slice: &crate::schema::metadata::SliceMetadata, backend: &JstproveBackend, proof_config: jstprove_circuits::api::ProofConfigType, weights_as_inputs: bool, jstprove_ops: &[&str], exclude_from_wai: &std::collections::HashSet, skip_compile_over_size: Option, circuit_cache: &CircuitCache, traced_shapes: Option<&std::collections::HashMap>>, holographic: bool, ) -> Result { let slice_dir = slice_dir_path(slices_dir, slice.index); if !slice_dir.exists() { return Err(DsperseError::Pipeline(format!( "slice directory not found: {}", slice_dir.display() ))); } if let Some(ref cs) = slice.channel_split && !cs.groups.is_empty() { return compile_channel_split_slice( slices_dir, slice, cs, backend, proof_config, jstprove_ops, exclude_from_wai, skip_compile_over_size, circuit_cache, traced_shapes, holographic, ); } if let Some(ref ds) = slice.dim_split && let Some(ref tmpl_rel) = ds.template_path { let tmpl_path = slices_dir.join(tmpl_rel); if tmpl_path.exists() { return compile_dim_split_template( 
slices_dir, slice, &tmpl_path, backend, proof_config, jstprove_ops, exclude_from_wai, skip_compile_over_size, circuit_cache, traced_shapes, holographic, ); } } let onnx_path = resolve_compile_onnx(slices_dir, slice)?; if !onnx_path.exists() { return Err(DsperseError::Pipeline(format!( "ONNX model not found for slice {}: {}", slice.index, onnx_path.display() ))); } let analysis = analyze_slice_onnx(&onnx_path, jstprove_ops)?; if !analysis.compatible { return Ok(CompileOutcome::Skipped); } if analysis.data_movement_only { tracing::info!(slice = slice.index, "skipped (data movement only)"); return Ok(CompileOutcome::Skipped); } // The threshold gate needs a concrete estimate; the debug // block below can reuse it so we only re-parse the slice ONNX // for constraint counting once per slice. let mut estimated: Option = None; if let Some(threshold) = skip_compile_over_size { let est = estimate_onnx_constraints(&onnx_path)?; estimated = Some(est); if est > threshold { return Ok(CompileOutcome::SkippedOverSize { estimated: est, threshold, }); } } let jst_dir = slice_dir.join("jstprove"); std::fs::create_dir_all(&jst_dir).map_err(|e| DsperseError::io(e, &jst_dir))?; let circuit_path = jst_dir.join("circuit.bundle"); if circuit_path.is_dir() { match backend.load_params(&circuit_path) { Ok(_) => { tracing::info!(slice = slice.index, "already compiled, skipping"); if holographic && !jstprove_io::bundle::bundle_has_vk(&circuit_path) { run_holographic_setup(backend, &circuit_path, slice.index, "slice")?; } return Ok(CompileOutcome::Compiled); } Err(e) => { tracing::warn!(slice = slice.index, error = %e, "cached circuit invalid, recompiling"); std::fs::remove_dir_all(&circuit_path) .map_err(|e| DsperseError::io(e, &circuit_path))?; } } } let effective_wai = weights_as_inputs; // The diagnostic bundle re-parses the slice ONNX (once for the // op summary, once for the constraint estimate if we didn't // already gate through it above). 
Skip that work when debug // tracing is disabled -- in a release build across hundreds of // slices it adds up. if tracing::enabled!(tracing::Level::DEBUG) { if estimated.is_none() { estimated = estimate_onnx_constraints(&onnx_path).ok(); } let op_summary = summarize_onnx_ops(&onnx_path); tracing::debug!( slice = slice.index, onnx = %onnx_path.display(), estimated_constraints = ?estimated, weights_as_inputs = effective_wai, ops = %op_summary, tiled = slice.tiling.is_some(), channel_split = slice.channel_split.is_some(), dim_split = slice.dim_split.is_some(), "compiling slice" ); } let compile_onnx = normalize_slice_for_backend(&onnx_path)?; let (params, architecture, wandb) = converter::prepare_jstprove_artifacts_filtered( compile_onnx.as_ref().unwrap_or(&onnx_path), effective_wai, exclude_from_wai, traced_shapes, )?; std::panic::catch_unwind(|| { backend.compile(&circuit_path, proof_config, params, architecture, wandb) }) .map_err(|p| { let msg = p .downcast_ref::<&str>() .copied() .or_else(|| p.downcast_ref::().map(String::as_str)) .unwrap_or("unknown panic"); DsperseError::Backend(format!("jstprove panicked: {msg}")) })??; if holographic { run_holographic_setup(backend, &circuit_path, slice.index, "slice")?; } Ok(CompileOutcome::Compiled) } /// Result of a [`setup_holographic_for_slices`] invocation. Mirrors /// the structure of [`CompileReport`] so callers can surface /// per-slice failures with the same handling. #[derive(Debug, Default)] pub struct HolographicSetupReport { pub processed: usize, pub skipped_already_present: usize, pub failed: Vec<(usize, DsperseError)>, } impl HolographicSetupReport { pub fn ok_if_no_failures(self) -> Result { if self.failed.is_empty() { Ok(self) } else { Err(DsperseError::Pipeline(format!( "setup_holographic_for_slices: {} slice bundle(s) failed", self.failed.len() ))) } } } /// Run holographic GKR setup over every compiled bundle under /// `slices_dir`. 
/// Walks the slice metadata and, for each slice,
/// processes the conventional bundle paths produced by
/// [`compile_slices`]: standard (`jstprove/circuit.bundle`),
/// channel-split (`jstprove/shared/circuit.bundle`), and dim-split
/// template (`jstprove/dim_split/circuit.bundle`).
///
/// Bundles that already carry a `vk.bin` are skipped unless
/// `overwrite` is set, so this function is idempotent and cheap to
/// re-run after a partial failure.
///
/// Bundles are processed on a dedicated rayon pool with `parallel`
/// threads. A failing bundle is recorded in the report's `failed` list
/// (as `(slice_index, error)`) instead of aborting the run; only a
/// missing/unloadable metadata file or a pool-construction failure is
/// returned as `Err`.
///
/// NOTE(review): generic arguments in this text (e.g. on `Result`,
/// `Mutex`, `HashSet`, `Option`, `downcast_ref::`) appear to have been
/// stripped by extraction; restore them from the real source.
pub fn setup_holographic_for_slices(
    slices_dir: &Path,
    backend: &JstproveBackend,
    parallel: usize,
    overwrite: bool,
) -> Result {
    let meta_path = find_metadata_path(slices_dir).ok_or_else(|| {
        DsperseError::Metadata(format!(
            "no {} found in slices directory",
            crate::utils::paths::METADATA_FILE
        ))
    })?;
    let metadata = ModelMetadata::load(&meta_path)?;

    // Collect every conventional bundle directory that exists on disk; a
    // slice can contribute zero, one, or several bundles.
    let mut targets: Vec<(usize, &'static str, PathBuf)> = Vec::new();
    for slice in &metadata.slices {
        let slice_dir = slice_dir_path(slices_dir, slice.index);
        let candidates: [(&'static str, PathBuf); 3] = [
            ("slice", slice_dir.join("jstprove").join("circuit.bundle")),
            (
                "channel-split-shared",
                slice_dir
                    .join("jstprove")
                    .join("shared")
                    .join("circuit.bundle"),
            ),
            (
                "dim-split-template",
                slice_dir
                    .join("jstprove")
                    .join("dim_split")
                    .join("circuit.bundle"),
            ),
        ];
        for (kind, path) in candidates {
            if path.is_dir() {
                targets.push((slice.index, kind, path));
            }
        }
    }

    tracing::info!(
        bundles = targets.len(),
        parallel,
        overwrite,
        "running holographic GKR setup over compiled bundles"
    );

    // A private pool, so `parallel` bounds this job's concurrency.
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(parallel)
        .build()
        .map_err(|e| DsperseError::Pipeline(format!("thread pool: {e}")))?;

    // Tallies shared by the workers; they are read only after
    // `pool.install` has returned.
    let processed = std::sync::atomic::AtomicUsize::new(0);
    let skipped = std::sync::atomic::AtomicUsize::new(0);
    let errors: std::sync::Mutex> = std::sync::Mutex::new(Vec::new());

    pool.install(|| {
        targets
            .par_iter()
            .for_each(|(slice_idx, kind, bundle_path)| {
                if !overwrite && jstprove_io::bundle::bundle_has_vk(bundle_path) {
                    skipped.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::info!(
                        slice = *slice_idx,
                        kind,
                        path = %bundle_path.display(),
                        "vk.bin already present, skipping (pass --overwrite to regenerate)"
                    );
                    return;
                }
                match run_holographic_setup(backend, bundle_path, *slice_idx, kind) {
                    Ok(()) => {
                        processed.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    }
                    Err(e) => {
                        tracing::warn!(slice = *slice_idx, kind, error = %e, "holographic setup failed");
                        // Record and continue so one bad bundle does not
                        // stop setup for the others.
                        errors.lock().unwrap().push((*slice_idx, e));
                    }
                }
            });
    });

    Ok(HolographicSetupReport {
        processed: processed.load(std::sync::atomic::Ordering::Relaxed),
        skipped_already_present: skipped.load(std::sync::atomic::Ordering::Relaxed),
        failed: errors.into_inner().unwrap(),
    })
}

/// Runs the backend's holographic verifying-key setup on one bundle
/// directory.
///
/// A panic inside the backend is caught and turned into
/// `DsperseError::Backend` naming the slice index and bundle kind, so it
/// surfaces as an ordinary error instead of unwinding further.
fn run_holographic_setup(
    backend: &JstproveBackend,
    circuit_path: &Path,
    slice_idx: usize,
    kind: &'static str,
) -> Result<()> {
    tracing::info!(
        slice = slice_idx,
        kind,
        path = %circuit_path.display(),
        "running holographic GKR setup"
    );
    // Outer `?` propagates a caught panic; inner `?` propagates the
    // backend's own `Err`.
    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        backend.setup_holographic_vk(circuit_path)
    }))
    .map_err(|p| {
        // Panic payloads are usually `&str` (literal message) or `String`
        // (formatted message); anything else gets a generic label.
        let msg = p
            .downcast_ref::<&str>()
            .copied()
            .or_else(|| p.downcast_ref::().map(String::as_str))
            .unwrap_or("unknown panic");
        DsperseError::Backend(format!(
            "jstprove panicked during holographic setup on slice {slice_idx} ({kind}): {msg}"
        ))
    })??;
    tracing::info!(slice = slice_idx, kind, "holographic vk persisted");
    Ok(())
}

/// Fills `cs.groups` (and `cs.bias_path` when `bias.msgpack` exists) from
/// the per-group ONNX files under `slice_{slice_idx}/payload/channel_groups/`.
///
/// Group `g` covers channels `[g * channels_per_group,
/// min((g + 1) * channels_per_group, c_in))`. Returns `Ok(false)` and
/// leaves `cs` unmodified if the directory or any `group_{g}.onnx` is
/// missing; returns `Ok(true)` after populating `cs`. Paths are stored
/// relative to `slices_dir`.
fn populate_channel_split_groups(
    slices_dir: &Path,
    slice_idx: usize,
    cs: &mut crate::schema::tiling::ChannelSplitInfo,
) -> Result {
    let groups_dir = slices_dir
        .join(format!("slice_{slice_idx}"))
        .join("payload")
        .join("channel_groups");
    if !groups_dir.exists() {
        return Ok(false);
    }
    let cpg = cs.channels_per_group;
    let mut groups = Vec::with_capacity(cs.num_groups);
    for g in 0..cs.num_groups {
        let c_start = g.checked_mul(cpg).ok_or_else(|| {
            DsperseError::Slicer(format!("overflow computing c_start for group {g}"))
        })?;
        // The last group is clamped to c_in.
        let c_end = (g + 1)
            .checked_mul(cpg)
            .map(|v| v.min(cs.c_in))
            .ok_or_else(|| {
                DsperseError::Slicer(format!("overflow computing c_end for group {g}"))
            })?;
        let rel_path = format!("slice_{slice_idx}/payload/channel_groups/group_{g}.onnx");
        let abs_path = slices_dir.join(&rel_path);
        if !abs_path.exists() {
            tracing::warn!(
                slice = slice_idx,
                group = g,
                "expected group ONNX not found, skipping population"
            );
            return Ok(false);
        }
        groups.push(crate::schema::tiling::ChannelGroupInfo {
            group_idx: g,
            c_start,
            c_end,
            path: rel_path,
            jstprove_circuit_path: None,
            jstprove_settings_path: None,
        });
    }
    let bias_rel = format!("slice_{slice_idx}/payload/channel_groups/bias.msgpack");
    if slices_dir.join(&bias_rel).exists() {
        cs.bias_path = Some(bias_rel);
    }
    tracing::info!(
        slice = slice_idx,
        groups = groups.len(),
        "populated channel split groups from materialized files"
    );
    cs.groups = groups;
    Ok(true)
}

/// Compiles (or reuses) the single shared circuit used by every group of a
/// channel-split slice, at `slice_{idx}/jstprove/shared/circuit.bundle`.
///
/// When a build is needed, the circuit is made from the first group's ONNX
/// (weights-as-inputs). It is either copied from a cached bundle with the
/// same circuit signature or compiled fresh. That build path returns
/// `SkippedOverSize` when the constraint estimate exceeds
/// `skip_compile_over_size`. When `holographic` is set, the bundle is made
/// to carry a `vk.bin`. On success every group maps to the shared bundle
/// path.
#[allow(clippy::too_many_arguments)]
fn compile_channel_split_slice(
    slices_dir: &Path,
    slice: &crate::schema::metadata::SliceMetadata,
    cs: &crate::schema::tiling::ChannelSplitInfo,
    backend: &JstproveBackend,
    proof_config: jstprove_circuits::api::ProofConfigType,
    jstprove_ops: &[&str],
    exclude_from_wai: &std::collections::HashSet,
    skip_compile_over_size: Option,
    circuit_cache: &CircuitCache,
    traced_shapes: Option<&std::collections::HashMap>>,
    holographic: bool,
) -> Result {
    let slice_dir = slice_dir_path(slices_dir, slice.index);
    let jst_dir = slice_dir.join("jstprove");
    std::fs::create_dir_all(&jst_dir).map_err(|e| DsperseError::io(e, &jst_dir))?;
    let shared_circuit_rel = format!("slice_{}/jstprove/shared/circuit.bundle", slice.index);
    let shared_circuit_path = jst_dir.join("shared").join("circuit.bundle");

    // Treat an existing shared bundle the same way the standard-
    // slice path does: try to load it; if load_params rejects it
    // (version drift, partial write, corruption), drop the stale
    // directory and fall through to the compile-fresh branch so a
    // single bad bundle doesn't permanently wedge every slice in
    // the channel-split group. The fresh-build code below is
    // unchanged and will re-populate from the circuit cache or
    // via backend.compile as appropriate.
    let mut needs_build = !shared_circuit_path.is_dir();
    if !needs_build {
        match backend.load_params(&shared_circuit_path) {
            Ok(_) => {
                tracing::info!(
                    slice = slice.index,
                    "shared circuit already compiled, reusing"
                );
            }
            Err(e) => {
                tracing::warn!(
                    slice = slice.index,
                    error = %e,
                    "cached shared circuit invalid, recompiling"
                );
                std::fs::remove_dir_all(&shared_circuit_path)
                    .map_err(|e| DsperseError::io(e, &shared_circuit_path))?;
                needs_build = true;
            }
        }
    }

    if needs_build {
        let first_group = cs.groups.first().ok_or_else(|| {
            DsperseError::Pipeline(format!("slice {} channel_split has no groups", slice.index))
        })?;
        let onnx_path = slices_dir.join(&first_group.path);
        if !onnx_path.exists() {
            return Err(DsperseError::Pipeline(format!(
                "channel group ONNX not found: {}",
                onnx_path.display()
            )));
        }
        let analysis = analyze_slice_onnx(&onnx_path, jstprove_ops)?;
        if !analysis.compatible {
            return Err(DsperseError::Pipeline(format!(
                "slice {} group 0 has unsupported ops for circuit compilation",
                slice.index
            )));
        }
        if let Some(threshold) = skip_compile_over_size {
            let estimated = estimate_onnx_constraints(&onnx_path)?;
            if estimated > threshold {
                return Ok(CompileOutcome::SkippedOverSize {
                    estimated,
                    threshold,
                });
            }
        }

        // Copy a circuit already compiled under the same signature if the
        // in-memory cache has one whose directory still exists; otherwise
        // compile fresh.
        let sig = compute_circuit_signature(&onnx_path, None)?;
        let cached = circuit_cache.lock().unwrap().get(&sig).cloned();
        if let Some(ref cached_path) = cached
            && cached_path.is_dir()
        {
            let shared_dir = shared_circuit_path.parent().ok_or_else(|| {
                DsperseError::Pipeline("shared circuit path has no parent".into())
            })?;
            std::fs::create_dir_all(shared_dir).map_err(|e| DsperseError::io(e, shared_dir))?;
            copy_dir_recursive(cached_path, &shared_circuit_path)?;
            tracing::info!(
                slice = slice.index,
                sig = %sig,
                "reused cached channel-split circuit from prior slice"
            );
        } else {
            let shared_dir = shared_circuit_path.parent().ok_or_else(|| {
                DsperseError::Pipeline("shared circuit path has no parent".into())
            })?;
            std::fs::create_dir_all(shared_dir).map_err(|e| DsperseError::io(e, shared_dir))?;
            tracing::info!(
                slice = slice.index,
                groups = cs.groups.len(),
                sig = %sig,
                "compiling shared channel group circuit (weights-as-inputs)"
            );
            let (params, architecture, wandb) = converter::prepare_jstprove_artifacts_filtered(
                &onnx_path,
                true,
                exclude_from_wai,
                traced_shapes,
            )?;
            // Turn a backend panic into an error naming the slice.
            std::panic::catch_unwind(|| {
                backend.compile(
                    &shared_circuit_path,
                    proof_config,
                    params,
                    architecture,
                    wandb,
                )
            })
            .map_err(|p| {
                let msg = p
                    .downcast_ref::<&str>()
                    .copied()
                    .or_else(|| p.downcast_ref::().map(String::as_str))
                    .unwrap_or("unknown panic");
                DsperseError::Backend(format!(
                    "jstprove panicked on slice {} shared circuit: {msg}",
                    slice.index
                ))
            })??;
            // The signature is published before any holographic setup
            // below persists vk.bin, so a concurrent reuser may copy a
            // bundle without a vk yet.
            circuit_cache
                .lock()
                .unwrap()
                .insert(sig.clone(), shared_circuit_path.clone());
            tracing::info!(slice = slice.index, sig = %sig, "shared circuit compiled");
        }

        // One final load to match the cached-bundle branch's
        // invariant: the function returns only after we have seen
        // a viable shared circuit at shared_circuit_path. If the
        // freshly-built bundle still fails to load, a retry would
        // recurse indefinitely, so surface the error.
        backend.load_params(&shared_circuit_path).map_err(|e| {
            DsperseError::Pipeline(format!(
                "slice {} freshly-built shared circuit failed to load: {e}",
                slice.index
            ))
        })?;
        if holographic && !jstprove_io::bundle::bundle_has_vk(&shared_circuit_path) {
            // When the needs_build branch took the memcache-reuse
            // sub-path, copy_dir_recursive may already have brought
            // vk.bin across from the source bundle; skip the
            // expensive re-setup in that case and only run it when
            // the shared bundle genuinely lacks a vk (the
            // fresh-compile sub-path, or a source that raced us
            // before its own setup persisted).
run_holographic_setup(
                backend,
                &shared_circuit_path,
                slice.index,
                "channel-split-shared",
            )?;
        }
    } else if holographic && !jstprove_io::bundle::bundle_has_vk(&shared_circuit_path) {
        // Cached bundle predates the holographic plumbing: backfill
        // the vk so reused circuits stay in sync with freshly-built
        // ones.
        run_holographic_setup(
            backend,
            &shared_circuit_path,
            slice.index,
            "channel-split-shared",
        )?;
    }

    // Every group of the slice points at the same shared bundle.
    let group_circuits: Vec<(usize, String)> = cs
        .groups
        .iter()
        .map(|g| (g.group_idx, shared_circuit_rel.clone()))
        .collect();
    Ok(CompileOutcome::CompiledChannelSplit { group_circuits })
}

/// Compiles (or reuses) the dim-split template circuit for `slice` at
/// `slice_{idx}/jstprove/dim_split/circuit.bundle`.
///
/// The sources are tried in this order:
/// 1. an existing bundle that `load_params` accepts (an invalid one is
///    deleted);
/// 2. a copy of a cached bundle with the same circuit signature;
/// 3. a fresh weights-as-inputs compile of `tmpl_path`.
///
/// Returns `Skipped` when the template has unsupported ops, and
/// `SkippedOverSize` when a known constraint estimate exceeds
/// `skip_compile_over_size`. `_traced_shapes` is intentionally unused.
/// When `holographic` is set, the resulting bundle is given a `vk.bin`.
///
/// NOTE(review): generic arguments on `HashSet`, `Option`, `HashMap` and
/// `Result` appear stripped by extraction; restore from the real source.
#[allow(clippy::too_many_arguments)]
fn compile_dim_split_template(
    slices_dir: &Path,
    slice: &crate::schema::metadata::SliceMetadata,
    tmpl_path: &Path,
    backend: &JstproveBackend,
    proof_config: jstprove_circuits::api::ProofConfigType,
    jstprove_ops: &[&str],
    exclude_from_wai: &std::collections::HashSet,
    skip_compile_over_size: Option,
    circuit_cache: &CircuitCache,
    _traced_shapes: Option<&std::collections::HashMap>>,
    holographic: bool,
) -> Result {
    let slice_dir = slice_dir_path(slices_dir, slice.index);
    let jst_dir = slice_dir.join("jstprove");
    std::fs::create_dir_all(&jst_dir).map_err(|e| DsperseError::io(e, &jst_dir))?;
    let circuit_path = jst_dir.join("dim_split").join("circuit.bundle");

    // Reuse an on-disk bundle if it still loads; otherwise delete it and
    // fall through to a rebuild.
    if circuit_path.is_dir() {
        match backend.load_params(&circuit_path) {
            Ok(_) => {
                tracing::info!(
                    slice = slice.index,
                    "dim-split template already compiled, reusing"
                );
                if holographic && !jstprove_io::bundle::bundle_has_vk(&circuit_path) {
                    // Backfill vk on cached bundles; see channel-
                    // split branch above for the same rationale.
                    run_holographic_setup(
                        backend,
                        &circuit_path,
                        slice.index,
                        "dim-split-template",
                    )?;
                }
                return Ok(CompileOutcome::CompiledDimSplit);
            }
            Err(e) => {
                tracing::warn!(slice = slice.index, error = %e, "cached dim-split circuit invalid, recompiling");
                std::fs::remove_dir_all(&circuit_path)
                    .map_err(|e| DsperseError::io(e, &circuit_path))?;
            }
        }
    }

    // Unsupported ops are a soft skip here, not an error.
    let analysis = analyze_slice_onnx(tmpl_path, jstprove_ops)?;
    if !analysis.compatible {
        return Ok(CompileOutcome::Skipped);
    }
    if let Some(threshold) = skip_compile_over_size {
        // Prefer a positive per-group estimate from the slice's dim-split
        // metadata; otherwise estimate from the template ONNX.
        let estimated = slice
            .dim_split
            .as_ref()
            .map(|ds| ds.estimated_group_constraints)
            .filter(|&e| e > 0)
            .or_else(|| match estimate_onnx_constraints(tmpl_path) {
                Ok(e) => Some(e),
                Err(err) => {
                    // We can't turn an unknown cost into a safe
                    // gating decision, so fall through and let the
                    // compile attempt surface the real error rather
                    // than silently treating the slice as tiny.
                    tracing::warn!(
                        slice = slice.index,
                        onnx = %tmpl_path.display(),
                        error = %err,
                        "skip_compile_over_size: constraint estimate failed; proceeding to compile"
                    );
                    None
                }
            });
        if let Some(estimated) = estimated
            && estimated > threshold
        {
            return Ok(CompileOutcome::SkippedOverSize {
                estimated,
                threshold,
            });
        }
    }

    let sig = compute_circuit_signature(tmpl_path, None)?;
    let cached = circuit_cache.lock().unwrap().get(&sig).cloned();
    if let Some(ref cached_path) = cached
        && cached_path.is_dir()
    {
        let shared_dir = circuit_path
            .parent()
            .ok_or_else(|| DsperseError::Pipeline("dim-split circuit path has no parent".into()))?;
        std::fs::create_dir_all(shared_dir).map_err(|e| DsperseError::io(e, shared_dir))?;
        copy_dir_recursive(cached_path, &circuit_path)?;
        tracing::info!(
            slice = slice.index,
            sig = %sig,
            "reused cached dim-split circuit"
        );
        // circuit_cache can hand back a source bundle that was
        // inserted before its own run_holographic_setup finished
        // (the fresh-build branch below inserts the sig before it
        // persists vk.bin), so a parallel racer can snapshot a
        // pre-vk source and copy_dir_recursive a bundle missing
        // vk.bin. Mirror the channel-split reuse branch and
        // backfill on the copy so every reused dim-split bundle
        // ends up in the same shape as a freshly-compiled one.
        if holographic && !jstprove_io::bundle::bundle_has_vk(&circuit_path) {
            run_holographic_setup(backend, &circuit_path, slice.index, "dim-split-template")?;
        }
        return Ok(CompileOutcome::CompiledDimSplit);
    }

    let shared_dir = circuit_path
        .parent()
        .ok_or_else(|| DsperseError::Pipeline("dim-split circuit path has no parent".into()))?;
    std::fs::create_dir_all(shared_dir).map_err(|e| DsperseError::io(e, shared_dir))?;
    tracing::info!(
        slice = slice.index,
        sig = %sig,
        "compiling dim-split template (weights-as-inputs)"
    );
    // Do NOT pass the original traced_shapes when compiling dim-split
    // templates. The template has rewritten shapes (dim_size → epg) that
    // differ from the original model's traced shapes. If traced_shapes
    // is passed, jstprove uses the original (larger) shapes and the
    // Transpose/Reshape validation fails on the mismatch.
let (params, architecture, wandb) =
        converter::prepare_jstprove_artifacts_filtered(tmpl_path, true, exclude_from_wai, None)?;
    // Turn a backend panic into an error naming the slice.
    std::panic::catch_unwind(|| {
        backend.compile(&circuit_path, proof_config, params, architecture, wandb)
    })
    .map_err(|p| {
        let msg = p
            .downcast_ref::<&str>()
            .copied()
            .or_else(|| p.downcast_ref::().map(String::as_str))
            .unwrap_or("unknown panic");
        DsperseError::Backend(format!(
            "jstprove panicked on slice {} dim-split template: {msg}",
            slice.index
        ))
    })??;
    // Published before the holographic setup below writes vk.bin; the
    // cache-reuse branch backfills vk for copies taken in that window.
    circuit_cache
        .lock()
        .unwrap()
        .insert(sig.clone(), circuit_path.clone());
    tracing::info!(slice = slice.index, sig = %sig, "dim-split template compiled");
    if holographic {
        run_holographic_setup(backend, &circuit_path, slice.index, "dim-split-template")?;
    }
    Ok(CompileOutcome::CompiledDimSplit)
}

/// Recursively copies the directory tree at `src` into `dst`, creating
/// `dst` as needed; existing files at the destination are overwritten by
/// `std::fs::copy`.
///
/// Entries are classified with `DirEntry::file_type`, which does not follow
/// symlinks, so any symlink goes down the `std::fs::copy` branch.
/// NOTE(review): a symlink to a directory would likely make that copy fail;
/// confirm bundles never contain symlinks.
fn copy_dir_recursive(src: &Path, dst: &Path) -> Result<()> {
    std::fs::create_dir_all(dst).map_err(|e| DsperseError::io(e, dst))?;
    for entry in std::fs::read_dir(src).map_err(|e| DsperseError::io(e, src))? {
        let entry = entry.map_err(|e| DsperseError::io(e, src))?;
        let ty = entry.file_type().map_err(|e| DsperseError::io(e, src))?;
        let dst_path = dst.join(entry.file_name());
        if ty.is_dir() {
            copy_dir_recursive(&entry.path(), &dst_path)?;
        } else {
            std::fs::copy(entry.path(), &dst_path).map_err(|e| DsperseError::io(e, &dst_path))?;
        }
    }
    Ok(())
}

/// Chooses the ONNX file to compile for `slice`: the tile ONNX named by
/// `slice.tiling.tile.path` (relative to `slices_dir`) when that file
/// exists, otherwise whatever `SliceMetadata::resolve_onnx` returns.
fn resolve_compile_onnx(
    slices_dir: &Path,
    slice: &crate::schema::metadata::SliceMetadata,
) -> Result {
    if let Some(ref tiling) = slice.tiling
        && let Some(ref tile) = tiling.tile
    {
        let tile_path = slices_dir.join(&tile.path);
        if tile_path.exists() {
            tracing::info!(
                slice = slice.index,
                path = %tile_path.display(),
                "using tile ONNX"
            );
            return Ok(tile_path);
        }
    }
    slice.resolve_onnx(slices_dir)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::metadata::{
        Compilation, Dependencies, SliceMetadata, SliceShapeWrapper, TensorShape,
    };
    use crate::schema::tiling::{TileInfo, TilingInfo};

    /// Shared ONNX fixtures live in `tests/models`, two directories above
    /// this crate's manifest directory.
    fn test_models_dir() -> std::path::PathBuf {
        std::path::PathBuf::from(concat!(env!("CARGO_MANIFEST_DIR"), "/../../tests/models"))
    }

    /// Builds a minimal `SliceMetadata` with only `index` and the ONNX path
    /// set; all optional sections are `None`/default.
    fn make_slice_metadata(index: usize, path: &str) -> SliceMetadata {
        SliceMetadata {
            index,
            filename: format!("slice_{index}.onnx"),
            path: path.to_string(),
            relative_path: path.to_string(),
            shape: SliceShapeWrapper {
                tensor_shape: TensorShape::default(),
            },
            dependencies: Dependencies {
                input: vec![],
                output: vec![],
                filtered_inputs: vec![],
            },
            tiling: None,
            channel_split: None,
            dim_split: None,
            compilation: Compilation::default(),
            slice_metadata: None,
            slice_metadata_relative_path: None,
        }
    }

    /// Op allow-list passed to `analyze_slice_onnx` by several tests.
    const TEST_OPS: &[&str] = &["Conv", "Gemm", "MatMul"];

    #[test]
    fn analyze_slice_onnx_nonexistent() {
        let result = analyze_slice_onnx(Path::new("/nonexistent.onnx"), TEST_OPS);
        assert!(result.is_err());
    }

    /// Expects the `net` fixture to be reported incompatible with
    /// `TEST_OPS`.
    #[test]
    fn analyze_slice_onnx_test_model() {
        let model_path = test_models_dir().join("net/model.onnx");
        assert!(
            model_path.exists(),
            "fixture missing: {}",
            model_path.display()
        );
        let analysis = analyze_slice_onnx(&model_path, TEST_OPS).unwrap();
        assert!(!analysis.compatible);
    }

    /// A single allowed `Conv` node stays compatible when the graph also
    /// carries an initializer.
    #[test]
    fn analyze_slice_onnx_with_initializers() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("with_init.onnx");
        let model = onnx_proto::ModelProto {
            graph: Some(onnx_proto::GraphProto {
                node: vec![onnx_proto::make_node("Conv", vec![], vec![], vec![])],
                initializer: vec![onnx_proto::make_tensor(
                    "weight",
                    1,
                    &[3, 3, 3, 3],
                    vec![0.0; 81],
                )],
                ..Default::default()
            }),
            ..Default::default()
        };
        onnx_proto::save_model(&model, &path).unwrap();
        let analysis = analyze_slice_onnx(&path, &["Conv"]).unwrap();
        assert!(analysis.compatible);
    }

    /// A single allowed `Relu` node with no initializers is compatible.
    #[test]
    fn analyze_slice_onnx_without_initializers() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("no_init.onnx");
        let model = onnx_proto::ModelProto {
            graph: Some(onnx_proto::GraphProto {
                node: vec![onnx_proto::make_node("Relu", vec![], vec![], vec![])],
                initializer: vec![],
                ..Default::default()
            }),
            ..Default::default()
        };
onnx_proto::save_model(&model, &path).unwrap();
        let analysis = analyze_slice_onnx(&path, &["Relu"]).unwrap();
        assert!(analysis.compatible);
    }

    /// Without tiling metadata the slice's own ONNX path is used.
    #[test]
    fn resolve_compile_onnx_no_tiling() {
        let tmp = tempfile::tempdir().unwrap();
        let slices_dir = tmp.path();
        let slice_dir = slices_dir.join("slice_0");
        std::fs::create_dir_all(&slice_dir).unwrap();
        let meta = make_slice_metadata(0, "slice_0.onnx");
        let path = resolve_compile_onnx(slices_dir, &meta).unwrap();
        assert!(path.ends_with("slice_0.onnx"));
    }

    /// An existing tile ONNX takes precedence over the slice ONNX.
    #[test]
    fn resolve_compile_onnx_with_tile() {
        let tmp = tempfile::tempdir().unwrap();
        let slices_dir = tmp.path();
        let tile_path = slices_dir.join("slice_0/payload/tiles/tile.onnx");
        std::fs::create_dir_all(tile_path.parent().unwrap()).unwrap();
        // Only existence is checked, so placeholder bytes suffice.
        std::fs::write(&tile_path, b"dummy").unwrap();
        let mut meta = make_slice_metadata(0, "slice_0.onnx");
        meta.tiling = Some(TilingInfo {
            slice_idx: 0,
            tile_size: 8,
            num_tiles: 4,
            tiles_y: 2,
            tiles_x: 2,
            halo: [1, 1, 1, 1],
            out_tile: [4, 4],
            stride: [1, 1],
            c_in: 3,
            c_out: 16,
            input_name: "input".into(),
            output_name: "output".into(),
            input_names: vec![],
            ndim: 4,
            h: 16,
            w: 16,
            tile: Some(TileInfo {
                path: "slice_0/payload/tiles/tile.onnx".into(),
                conv_out: [4, 4],
                jstprove_circuit_path: None,
            }),
            tiles: None,
            segment_size: None,
            total_elements: None,
            original_shape: vec![],
        });
        let path = resolve_compile_onnx(slices_dir, &meta).unwrap();
        assert!(path.ends_with("tile.onnx"));
    }

    /// A tile path that does not exist falls back to the slice ONNX.
    #[test]
    fn resolve_compile_onnx_tile_missing_falls_back() {
        let tmp = tempfile::tempdir().unwrap();
        let slices_dir = tmp.path();
        let slice_dir = slices_dir.join("slice_0");
        std::fs::create_dir_all(&slice_dir).unwrap();
        let mut meta = make_slice_metadata(0, "slice_0.onnx");
        meta.tiling = Some(TilingInfo {
            slice_idx: 0,
            tile_size: 8,
            num_tiles: 4,
            tiles_y: 2,
            tiles_x: 2,
            halo: [1, 1, 1, 1],
            out_tile: [4, 4],
            stride: [1, 1],
            c_in: 3,
            c_out: 16,
            input_name: "input".into(),
            output_name: "output".into(),
            input_names: vec![],
            ndim: 4,
            h: 16,
            w: 16,
            tile: Some(TileInfo {
                path: "slice_0/payload/tiles/nonexistent.onnx".into(),
                conv_out: [4, 4],
                jstprove_circuit_path: None,
            }),
            tiles: None,
            segment_size: None,
            total_elements: None,
            original_shape: vec![],
        });
        let path = resolve_compile_onnx(slices_dir, &meta).unwrap();
        assert!(path.ends_with("slice_0.onnx"));
    }

    /// Writes a minimal single-node `Relu` model (`x` -> `y`, opset 13) to
    /// `path` for the signature tests.
    fn write_identity_onnx(path: &Path) {
        let node = onnx_proto::NodeProto {
            op_type: "Relu".to_string(),
            input: vec!["x".to_string()],
            output: vec!["y".to_string()],
            ..Default::default()
        };
        let graph = onnx_proto::make_graph(
            "g",
            vec![node],
            vec![onnx_proto::make_tensor_value_info("x", 1, &[1, 8])],
            vec![onnx_proto::make_tensor_value_info("y", 1, &[1, 8])],
            vec![],
        );
        let model = onnx_proto::make_model(graph, 13);
        onnx_proto::save_model(&model, path).unwrap();
    }

    /// Even for an empty bundle directory the bundle signature must differ
    /// from the plain circuit signature, and it must be deterministic.
    #[test]
    fn bundle_signature_differs_from_circuit_signature_even_without_metadata() {
        let tmp = tempfile::tempdir().unwrap();
        let onnx_path = tmp.path().join("slice.onnx");
        write_identity_onnx(&onnx_path);
        let bundle_dir = tmp.path().join("bundle");
        std::fs::create_dir_all(&bundle_dir).unwrap();
        let base = compute_circuit_signature(&onnx_path, None).unwrap();
        let bundle_sig = compute_bundle_signature(&onnx_path, None, &bundle_dir).unwrap();
        assert_ne!(
            base, bundle_sig,
            "bundle signature must always include discriminator bytes"
        );
        let bundle_sig_again = compute_bundle_signature(&onnx_path, None, &bundle_dir).unwrap();
        assert_eq!(
            bundle_sig, bundle_sig_again,
            "bundle signature must be deterministic"
        );
    }

    /// The presence of `vk.bin` alone must change the bundle signature.
    #[test]
    fn bundle_signature_disambiguates_vk_presence() {
        let tmp = tempfile::tempdir().unwrap();
        let onnx_path = tmp.path().join("slice.onnx");
        write_identity_onnx(&onnx_path);
        let plain_bundle = tmp.path().join("plain");
        let holo_bundle = tmp.path().join("holographic");
        std::fs::create_dir_all(&plain_bundle).unwrap();
        std::fs::create_dir_all(&holo_bundle).unwrap();
        std::fs::write(holo_bundle.join("vk.bin"), b"vk-contents").unwrap();
        let plain_sig = compute_bundle_signature(&onnx_path, None, &plain_bundle).unwrap();
        let holo_sig = compute_bundle_signature(&onnx_path, None, &holo_bundle).unwrap();
        assert_ne!(
            plain_sig, holo_sig,
            "holographic bundle must produce a distinct signature"
        );
    }

    /// For bundles written with circuit params, the signature must depend on
    /// the stamped proof config and on `weights_as_inputs`.
    #[test]
    fn bundle_signature_disambiguates_proof_config_and_wai_on_metadata_branch() {
        use std::collections::HashMap;
        use jstprove_circuits::ProofSystem;
        use jstprove_circuits::api::{CircuitParamsType, ProofConfigType, StampedProofConfigType};
        use jstprove_io::bundle::write_bundle;

        let tmp = tempfile::tempdir().unwrap();
        let onnx_path = tmp.path().join("slice.onnx");
        write_identity_onnx(&onnx_path);

        // Params that vary only in proof config and weights_as_inputs.
        fn make_params(config: ProofConfigType, weights_as_inputs: bool) -> CircuitParamsType {
            CircuitParamsType {
                scale_base: 2,
                scale_exponent: 8,
                rescale_config: HashMap::new(),
                inputs: Vec::new(),
                outputs: Vec::new(),
                freivalds_reps: 1,
                n_bits_config: HashMap::new(),
                weights_as_inputs,
                proof_system: ProofSystem::default(),
                proof_config: Some(StampedProofConfigType::current(config)),
                logup_chunk_bits: None,
                public_inputs: Vec::new(),
            }
        }

        let bn254_bundle = tmp.path().join("bn254");
        let goldi_bundle = tmp.path().join("goldilocks");
        let bn254_wai_bundle = tmp.path().join("bn254-wai");
        write_bundle(
            &bn254_bundle,
            &[1, 2, 3],
            &[4, 5, 6],
            Some(make_params(ProofConfigType::Bn254Raw, false)),
            None,
            false,
        )
        .unwrap();
        write_bundle(
            &goldi_bundle,
            &[1, 2, 3],
            &[4, 5, 6],
            Some(make_params(ProofConfigType::GoldilocksExt4Whir, false)),
            None,
            false,
        )
        .unwrap();
        write_bundle(
            &bn254_wai_bundle,
            &[1, 2, 3],
            &[4, 5, 6],
            Some(make_params(ProofConfigType::Bn254Raw, true)),
            None,
            false,
        )
        .unwrap();

        let sig_bn254 = compute_bundle_signature(&onnx_path, None, &bn254_bundle).unwrap();
        let sig_goldi = compute_bundle_signature(&onnx_path, None, &goldi_bundle).unwrap();
        let sig_bn254_wai = compute_bundle_signature(&onnx_path, None, &bn254_wai_bundle).unwrap();
        assert_ne!(
            sig_bn254, sig_goldi,
            "config_id must discriminate bundles with different ProofConfig variants"
        );
        assert_ne!(
            sig_bn254, sig_bn254_wai,
            "weights_as_inputs must discriminate bundles with the same ProofConfig"
        );
        let sig_bn254_again = compute_bundle_signature(&onnx_path, None, &bn254_bundle).unwrap();
        assert_eq!(
            sig_bn254, sig_bn254_again,
            "signature must be deterministic for the metadata branch"
        );
    }
}
================================================
FILE: crates/dsperse/src/pipeline/dim_split.rs
================================================
use std::collections::HashMap;
use std::path::Path;

use super::runner::{run_onnx_inference, run_onnx_inference_multi_named};
use super::tensor_store::TensorStore;
use crate::backend::jstprove::JstproveBackend;
use crate::error::{DsperseError, Result};
use crate::schema::execution::ExecutionInfo;
use crate::schema::tiling::DimSplitKind;
use crate::slicer::onnx_proto::TensorProto;

/// Executes a dim-split slice by running its template ONNX over pieces of
/// the input and reassembling the result.
///
/// `DimSplitKind::MatMulOutputDim` goes through the MatMul/Gemm K-tiling
/// path; every other split kind goes through the generic group path. Fails
/// if `ds.template_path` is unset or the template file does not exist.
/// `_slice_run_dir` and `_backend` are unused.
///
/// NOTE(review): generic arguments on `Option<&HashMap ...>` and `Result`
/// appear stripped by extraction; restore from the real source.
#[allow(clippy::too_many_arguments)]
pub(crate) fn execute_dim_split(
    slices_dir: &Path,
    _slice_run_dir: &Path,
    slice_id: &str,
    ds: &crate::schema::tiling::DimSplitInfo,
    target_shape: Option<&[i64]>,
    tensor_cache: &TensorStore,
    _backend: &JstproveBackend,
    donor_init_map: Option<&HashMap>,
) -> Result {
    let tmpl_rel = ds.template_path.as_ref().ok_or_else(|| {
        DsperseError::Pipeline(format!("{slice_id}: dim_split has no template_path"))
    })?;
    let tmpl_path = slices_dir.join(tmpl_rel);
    if !tmpl_path.exists() {
        return Err(DsperseError::Pipeline(format!(
            "{slice_id}: dim-split template not found: {}",
            tmpl_path.display()
        )));
    }
    let use_matmul_split = matches!(ds.split_kind, DimSplitKind::MatMulOutputDim);
    let final_result = if use_matmul_split {
        execute_matmul_dim_split(
            slices_dir,
            slice_id,
            ds,
            target_shape,
            tensor_cache,
            &tmpl_path,
            donor_init_map,
        )?
    } else {
        execute_generic_dim_split(slice_id, ds, target_shape, tensor_cache, &tmpl_path)?
};
    Ok(crate::schema::execution::StrategyOutput {
        info: ExecutionInfo {
            method: crate::schema::execution::ExecutionMethod::DimSplit,
            success: true,
            error: None,
            witness_file: None,
            tile_exec_infos: Vec::new(),
        },
        outputs: vec![(ds.output_name.clone(), final_result)],
    })
}

/// Evaluates a MatMul/Gemm dim-split slice row by row, tiling the
/// contraction dimension K into `ds.k_chunks` chunks.
///
/// Where the weight comes from:
/// - the weight named by `ds.weight_name` is taken from `donor_init_map`
///   when that map contains it;
/// - otherwise it is taken from the slice ONNX's initializers.
///
/// For each K chunk, the template's `W` initializer is overwritten with the
/// matching weight slice (zero-padded to `k_chunk_size`) and saved to a temp
/// file. Every input row is run through each chunk model and the chunk
/// outputs are summed. The stacked rows are reshaped to `target_shape`, or
/// to the input shape with its last dimension replaced by `ds.n_dim`.
///
/// Errors when:
/// - the runtime K disagrees with a non-zero metadata `k_dim`, or K is zero;
/// - the weight name, the weight, or the matching MatMul/Gemm node is
///   missing;
/// - the weight length differs from a non-zero `k_dim * n_dim`;
/// - a chunk does not produce exactly `n_dim` outputs.
///
/// NOTE(review): generic arguments in this text (e.g. `Result<...>`,
/// `Vec<...>`) appear stripped by extraction; they are kept as found and
/// must be restored from the real source.
#[allow(clippy::too_many_arguments)]
fn execute_matmul_dim_split(
    slices_dir: &Path,
    slice_id: &str,
    ds: &crate::schema::tiling::DimSplitInfo,
    target_shape: Option<&[i64]>,
    tensor_cache: &TensorStore,
    tmpl_path: &Path,
    donor_init_map: Option<&HashMap>,
) -> Result> {
    let input_tensor = tensor_cache.get(&ds.input_name)?.clone();
    let input_shape = input_tensor.shape().to_vec();
    let k_dim = *input_shape.last().unwrap_or(&0);
    if ds.k_dim != 0 && k_dim != ds.k_dim {
        return Err(DsperseError::Pipeline(format!(
            "{slice_id}: runtime k_dim {} from input {:?} does not match metadata k_dim {}",
            k_dim, ds.input_name, ds.k_dim
        )));
    }
    if k_dim == 0 {
        return Err(DsperseError::Pipeline(format!(
            "{slice_id}: dim-split input {:?} has zero-width last dim; expected k_dim > 0",
            ds.input_name
        )));
    }

    let k_chunks = ds.k_chunks.max(1);
    // ceil(K / k_chunks). When k_chunks does not divide K evenly, trailing
    // chunks can start at or past K (K = 5, k_chunks = 4 gives starts
    // 0, 2, 4, 6) and carry no real data; they are zero-padded below.
    let k_chunk_size = k_dim.div_ceil(k_chunks);
    // All leading dimensions are folded into rows of width K.
    let total_rows: usize = input_shape
        .iter()
        .take(input_shape.len().saturating_sub(1))
        .product();
    let flat_input = input_tensor
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order(ndarray::IxDyn(&[total_rows, k_dim]))
        .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: flatten input: {e}")))?;

    let slice_onnx_path = slices_dir
        .join(format!("slice_{}", ds.slice_idx))
        .join("payload")
        .join(format!("slice_{}.onnx", ds.slice_idx));
    let orig_model = crate::slicer::onnx_proto::load_model(&slice_onnx_path)?;
    let orig_graph = orig_model
        .graph
        .as_ref()
        .ok_or_else(|| DsperseError::Pipeline(format!("{slice_id}: slice ONNX has no graph")))?;
    let weight_name = ds.weight_name.as_ref().ok_or_else(|| {
        DsperseError::Pipeline(format!(
            "{slice_id}: dim_split missing weight_name in metadata"
        ))
    })?;
    let matmul_node = orig_graph
        .node
        .iter()
        .find(|n| {
            matches!(n.op_type.as_str(), "MatMul" | "Gemm")
                && n.input.iter().any(|i| i == weight_name)
                && n.input.iter().any(|i| i == &ds.input_name)
                && n.output.iter().any(|o| o == &ds.output_name)
        })
        .ok_or_else(|| {
            DsperseError::Pipeline(format!(
                "{slice_id}: no MatMul/Gemm node matches weight={weight_name:?} input={:?} output={:?}",
                ds.input_name, ds.output_name
            ))
        })?;
    let trans_b = matmul_node.op_type == "Gemm"
        && crate::slicer::onnx_proto::get_attribute_int(matmul_node, "transB").unwrap_or(0) == 1;

    let full_weight: Vec = if let Some(map) = donor_init_map
        && let Some(t) = map.get(weight_name.as_str())
    {
        crate::slicer::onnx_proto::tensor_to_f32(t)
    } else {
        let init = orig_graph
            .initializer
            .iter()
            .find(|i| i.name == *weight_name)
            .ok_or_else(|| {
                DsperseError::Pipeline(format!(
                    "{slice_id}: weight {weight_name:?} not found in slice ONNX initializers"
                ))
            })?;
        crate::slicer::onnx_proto::tensor_to_f32(init)
    };
    // Only enforced when both metadata dims are known (non-zero).
    let expected_weight_len = ds.k_dim.saturating_mul(ds.n_dim);
    if expected_weight_len > 0 && full_weight.len() != expected_weight_len {
        return Err(DsperseError::Pipeline(format!(
            "{slice_id}: weight {weight_name:?} length {} does not match expected k_dim*n_dim = {}*{} = {}",
            full_weight.len(),
            ds.k_dim,
            ds.n_dim,
            expected_weight_len
        )));
    }

    let n_dim = ds.n_dim;
    let tmpl_model = crate::slicer::onnx_proto::load_model(tmpl_path)?;
    let tmp_dir = tempfile::tempdir()
        .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: tmpdir: {e}")))?;

    // Materialize one patched template per K chunk.
    let mut patched_paths: Vec = Vec::with_capacity(k_chunks);
    for kc in 0..k_chunks {
        let k_start = kc * k_chunk_size;
        let k_end = (k_start + k_chunk_size).min(k_dim);
        let actual_k = k_end.saturating_sub(k_start);
        // The indexing below assumes a row-major [n_dim, k_dim] weight when
        // transB = 1 and a row-major [k_dim, n_dim] weight otherwise.
        let weight_chunk: Vec = if trans_b {
            let mut w = Vec::with_capacity(n_dim * k_chunk_size);
            for row_idx in 0..n_dim {
                let row_start = row_idx * k_dim + k_start;
                let avail = actual_k.min(full_weight.len().saturating_sub(row_start));
                // For a chunk that starts past K, row_start on the last row
                // exceeds full_weight.len(); indexing
                // `full_weight[row_start..row_start]` would then panic even
                // though nothing is copied. Copy only real data and let the
                // resize below zero-pad the rest.
                if avail > 0 {
                    w.extend_from_slice(&full_weight[row_start..row_start + avail]);
                }
                if avail < k_chunk_size {
                    w.resize(w.len() + k_chunk_size - avail, 0.0);
                }
            }
            w
        } else {
            let mut w = Vec::with_capacity(k_chunk_size * n_dim);
            for ki in k_start..k_start + actual_k {
                let start = ki * n_dim;
                let end = start + n_dim;
                if end <= full_weight.len() {
                    w.extend_from_slice(&full_weight[start..end]);
                } else {
                    w.resize(w.len() + n_dim, 0.0);
                }
            }
            if actual_k < k_chunk_size {
                w.resize(k_chunk_size * n_dim, 0.0);
            }
            w
        };
        let mut patched = tmpl_model.clone();
        let graph = patched.graph.as_mut().ok_or_else(|| {
            DsperseError::Pipeline(format!(
                "{slice_id}: dim-split template at {} has no graph",
                tmpl_path.display()
            ))
        })?;
        let w_init = graph
            .initializer
            .iter_mut()
            .find(|i| i.name == "W")
            .ok_or_else(|| {
                DsperseError::Pipeline(format!(
                    "{slice_id}: dim-split template at {} missing 'W' initializer",
                    tmpl_path.display()
                ))
            })?;
        // Replace the weight payload; clearing raw_data keeps float_data as
        // the only source of the values.
        w_init.float_data = weight_chunk;
        w_init.raw_data.clear();
        let patched_path = tmp_dir.path().join(format!("chunk_{kc}.onnx"));
        crate::slicer::onnx_proto::save_model(&patched, &patched_path)?;
        patched_paths.push(patched_path);
    }

    // For each row, sum the partial products of all K chunks.
    let mut row_outputs: Vec> = Vec::with_capacity(total_rows);
    for r in 0..total_rows {
        let full_row: Vec = flat_input
            .slice(ndarray::s![r, ..])
            .iter()
            .copied()
            .collect();
        let mut row_accum = vec![0.0f64; n_dim];
        for (kc, patched_path) in patched_paths.iter().enumerate() {
            let k_start = kc * k_chunk_size;
            let k_end = (k_start + k_chunk_size).min(k_dim);
            let actual_k = k_end.saturating_sub(k_start);
            let mut input_chunk = vec![0.0f64; k_chunk_size];
            if actual_k > 0 {
                input_chunk[..actual_k].copy_from_slice(&full_row[k_start..k_end]);
            }
            let input_arr =
                ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[1, k_chunk_size]), input_chunk)
                    .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: input chunk: {e}")))?;
            let out = run_onnx_inference(patched_path, &input_arr)?;
            if out.len() != n_dim {
                return Err(DsperseError::Pipeline(format!(
                    "{slice_id}: dim-split k-chunk {kc} produced {} outputs, expected n_dim={n_dim}",
                    out.len()
                )));
            }
            for (acc, v) in row_accum.iter_mut().zip(out.iter().copied()) {
                *acc += v;
            }
        }
        let row_arr = ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[1, n_dim]), row_accum)
            .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: row output: {e}")))?;
        row_outputs.push(row_arr);
    }

    let stacked: ndarray::ArrayD = if row_outputs.is_empty() {
        ndarray::ArrayD::zeros(ndarray::IxDyn(&[0, n_dim]))
    } else {
        ndarray::concatenate(
            ndarray::Axis(0),
            &row_outputs.iter().map(|a| a.view()).collect::>(),
        )
        .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: row concat: {e}")))?
    };
    let output_shape_vec = resolve_output_shape(slice_id, &input_shape, n_dim, target_shape)?;
    let final_result = stacked
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order(ndarray::IxDyn(&output_shape_vec))
        .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: dim-split reshape: {e}")))?;
    tracing::info!(
        slice = %slice_id,
        rows = total_rows,
        k_chunks = k_chunks,
        "executed dim-split (sequence + K tiled)"
    );
    Ok(final_result)
}

/// Evaluates a non-MatMul dim-split slice.
///
/// Every template input whose extent along `ds.split_dim` equals
/// `ds.dim_size` is cut into `ds.num_groups` windows of
/// `ds.elements_per_group`. Other inputs (including those lacking that axis)
/// are passed whole to every group. The template is run once per group and
/// the group outputs are concatenated along `ds.concat_axis`. The result is
/// then reshaped to `target_shape` when given (negative dims are rejected),
/// or returned as concatenated.
///
/// Errors when the groups do not cover the whole split dimension
/// (`num_groups * elements_per_group < dim_size`, which includes
/// `elements_per_group == 0`), since the uncovered tail would otherwise be
/// silently dropped from the output.
fn execute_generic_dim_split(
    slice_id: &str,
    ds: &crate::schema::tiling::DimSplitInfo,
    target_shape: Option<&[i64]>,
    tensor_cache: &TensorStore,
    tmpl_path: &Path,
) -> Result> {
    use ndarray::Axis;

    let concat_axis = ds.concat_axis;
    let split_dim = ds.split_dim;
    let epg = ds.elements_per_group;

    // Groups only ever read [0, num_groups * epg) along split_dim; reject
    // metadata that would leave a tail unprocessed instead of returning a
    // truncated tensor.
    if ds.num_groups.saturating_mul(epg) < ds.dim_size {
        return Err(DsperseError::Pipeline(format!(
            "{slice_id}: dim-split groups do not cover split dim: num_groups {} * elements_per_group {epg} < dim_size {}",
            ds.num_groups, ds.dim_size
        )));
    }

    let tmpl_model = crate::slicer::onnx_proto::load_model(tmpl_path)?;
    let tmpl_graph = tmpl_model
        .graph
        .as_ref()
        .ok_or_else(|| DsperseError::Pipeline(format!("{slice_id}: template has no graph")))?;
    // Graph inputs that are not initializers must be fed from tensor_cache.
    let tmpl_init_names: std::collections::HashSet<&str> = tmpl_graph
        .initializer
        .iter()
        .map(|i| i.name.as_str())
        .collect();
    let input_names: Vec = tmpl_graph
        .input
        .iter()
        .filter(|vi| !tmpl_init_names.contains(vi.name.as_str()))
        .map(|vi| vi.name.clone())
        .collect();
    let tmp_dir = tempfile::tempdir()
        .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: tmpdir: {e}")))?;
    let tmpl_on_disk = tmp_dir.path().join("dim_tmpl.onnx");
    crate::slicer::onnx_proto::save_model(&tmpl_model, &tmpl_on_disk)?;

    let mut group_outputs: Vec> = Vec::new();
    for g in 0..ds.num_groups {
        let dim_start = g * epg;
        if dim_start >= ds.dim_size {
            break;
        }
        let dim_end = ((g + 1) * epg).min(ds.dim_size);
        let actual_size = dim_end - dim_start;
        // dim_size is required to be an exact multiple of epg by the
        // detector (`smallest_divisor_at_least`), so every group is
        // exactly `epg` wide and we can feed the sliced view straight
        // in -- no zero-padding, no output trimming, no risk of
        // contaminating reductions on non-split axes.
        debug_assert_eq!(
            actual_size, epg,
            "dim-split detector must enforce dim_size % epg == 0"
        );
        let mut group_cache = TensorStore::new();
        for vi_name in &input_names {
            let arr = tensor_cache.try_get(vi_name).ok_or_else(|| {
                DsperseError::Pipeline(format!(
                    "{slice_id}: template input {vi_name:?} not found in tensor cache"
                ))
            })?;
            let shape = arr.shape();
            if split_dim < shape.len() && shape[split_dim] == ds.dim_size {
                let sliced = arr
                    .slice_axis(Axis(split_dim), ndarray::Slice::from(dim_start..dim_end))
                    .to_owned();
                group_cache.put(vi_name.clone(), sliced);
            } else {
                group_cache.put(vi_name.clone(), arr.clone());
            }
        }
        let mut named =
            run_onnx_inference_multi_named(&tmpl_on_disk, &group_cache, &input_names)?;
        let (data, shape) = named.remove(&ds.output_name).ok_or_else(|| {
            DsperseError::Pipeline(format!(
                "{slice_id}: dim-split group {g} missing output {:?} (available: {:?})",
                ds.output_name,
                named.keys().collect::>()
            ))
        })?;
        let group_output = ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shape), data)
            .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: group {g} reshape: {e}")))?;
        // Output is naturally `epg` wide along concat_axis, so it is kept
        // as-is (no trimming).
        group_outputs.push(group_output);
    }

    let result = ndarray::concatenate(
        Axis(concat_axis),
        &group_outputs.iter().map(|a| a.view()).collect::>(),
    )
    .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: dim-split concat: {e}")))?;
    let output_shape_vec = if let Some(target) = target_shape {
        target
            .iter()
            .map(|&d| {
                usize::try_from(d).map_err(|_| {
                    DsperseError::Pipeline(format!(
                        "{slice_id}: invalid target dim {d} in dim-split reshape"
                    ))
                })
            })
            .collect::>>()?
    } else {
        result.shape().to_vec()
    };
    let final_result = result
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order(ndarray::IxDyn(&output_shape_vec))
        .map_err(|e| DsperseError::Pipeline(format!("{slice_id}: dim-split reshape: {e}")))?;
    tracing::info!(
        slice = %slice_id,
        groups = ds.num_groups,
        split_kind = ?ds.split_kind,
        "executed dim-split (generic)"
    );
    Ok(final_result)
}

/// Resolves the output shape of a MatMul dim-split.
///
/// With `target_shape`, each dim is converted to `usize` and a negative dim
/// is an error. Without it, the result is `input_shape` with its last
/// dimension replaced by `n_dim`.
fn resolve_output_shape(
    slice_id: &str,
    input_shape: &[usize],
    n_dim: usize,
    target_shape: Option<&[i64]>,
) -> Result> {
    if let Some(target) = target_shape {
        target
            .iter()
            .map(|&d| {
                usize::try_from(d).map_err(|_| {
                    DsperseError::Pipeline(format!(
                        "{slice_id}: invalid target dimension {d} in dim-split reshape"
                    ))
                })
            })
            .collect::>>()
    } else {
        let mut s = input_shape.to_vec();
        if let Some(last) = s.last_mut() {
            *last = n_dim;
        }
        Ok(s)
    }
}
================================================
FILE: crates/dsperse/src/pipeline/incremental.rs
================================================
use std::path::{Path, PathBuf};

use ndarray::ArrayD;

use crate::error::{DsperseError, Result};
use crate::schema::execution::{ExecutionChain, ExecutionInfo, ExecutionResultEntry, RunMetadata};
use crate::schema::metadata::{BackendKind, ModelMetadata, RunSliceMetadata};
use crate::schema::tiling::{ChannelSplitInfo, TilingInfo};

use super::runner::{build_execution_chain, build_run_metadata, load_model_metadata};
use super::strategy::ExecutionStrategy;
use super::tensor_store::TensorStore;

/// One slice's pending execution: its inputs, backend selection, optional
/// tiling/channel-split info, artifact paths and run metadata.
pub struct SliceWork {
    pub slice_id:
String,
    // Tensor handed to the executor. For `Single` slices this is
    // `TensorStore::gather` over the slice's filtered inputs; for the split
    // and tiled strategies it is the strategy's own input tensor.
    pub input: ArrayD,
    // Per-name copies of the filtered inputs; populated only for `Single`
    // slices (empty for channel-split, dim-split and tiled slices).
    pub named_inputs: Vec<(String, ArrayD)>,
    pub backend: BackendKind,
    pub use_circuit: bool,
    pub tiling: Option,
    pub channel_split: Option,
    pub circuit_path: Option,
    pub onnx_path: Option,
    pub slice_meta: RunSliceMetadata,
}

// NOTE(review): generic arguments in this extracted copy appear to have been
// stripped (e.g. `Option,`, `Result>`); restore them from upstream.

/// Outcome of executing a `SliceWork`, passed back to
/// `IncrementalRun::apply_result`.
pub struct SliceExecutionResult {
    pub slice_id: String,
    pub output: ArrayD,
    pub execution_info: ExecutionInfo,
}

/// Step-by-step driver for a sliced model.
///
/// `next_slice` hands out one slice at a time and the caller performs the
/// actual execution. `apply_result` records the returned output and advances
/// along the execution chain. Intermediate tensors are kept in a
/// `TensorStore`, keyed by tensor name.
pub struct IncrementalRun {
    tensor_cache: TensorStore,
    execution_chain: ExecutionChain,
    model_meta: ModelMetadata,
    run_meta: RunMetadata,
    slices_dir: PathBuf,
    // Slice expected next; `None` once the chain has been fully executed.
    current_slice: Option,
    // One entry per applied result, in execution order.
    results: Vec,
}

impl IncrementalRun {
    /// Start a run over the slices in `slices_dir`.
    ///
    /// Loads the model metadata, builds the execution chain and run metadata,
    /// and seeds the tensor cache with `input` under the first slice's
    /// filtered input name.
    ///
    /// Errors if the model has no slices, or if the first slice does not
    /// declare exactly one filtered input (multi-input models are unsupported).
    pub fn new(slices_dir: &Path, input: ArrayD) -> Result {
        let model_meta = load_model_metadata(slices_dir)?;
        let chain = build_execution_chain(&model_meta, slices_dir)?;
        let run_meta = build_run_metadata(&model_meta, slices_dir, &chain)?;
        let first_slice = model_meta
            .slices
            .first()
            .ok_or_else(|| DsperseError::Pipeline("model has no slices".into()))?;
        let filtered = &first_slice.dependencies.filtered_inputs;
        if filtered.len() != 1 {
            return Err(DsperseError::Pipeline(format!(
                "multi-input models not supported: first slice declares {} filtered inputs",
                filtered.len()
            )));
        }
        let input_name = filtered[0].clone();
        let mut tensor_cache = TensorStore::new();
        tensor_cache.put(input_name, input);
        // Execution starts at the head of the chain.
        let current_slice = chain.head.clone();
        Ok(Self {
            tensor_cache,
            execution_chain: chain,
            model_meta,
            run_meta,
            slices_dir: slices_dir.to_path_buf(),
            current_slice,
            results: Vec::new(),
        })
    }

    /// Describe the work for the current slice without advancing the run.
    ///
    /// Returns `Ok(None)` once the run is complete. Errors if the execution
    /// chain or run metadata lacks the slice, or if a required input tensor
    /// is not in the cache.
    pub fn next_slice(&self) -> Result> {
        let slice_id = match self.current_slice.as_ref() {
            Some(id) => id,
            None => return Ok(None),
        };
        let node = self.execution_chain.nodes.get(slice_id).ok_or_else(|| {
            DsperseError::Pipeline(format!("execution chain missing node for {slice_id}"))
        })?;
        let meta = self.run_meta.slices.get(slice_id).ok_or_else(|| {
            DsperseError::Pipeline(format!("run metadata missing slice {slice_id}"))
        })?;
        let strategy = ExecutionStrategy::from_metadata(meta, node.use_circuit)?;
        // Split and tiled strategies consume one named tensor. `Single` slices
        // get every filtered input, both individually and gathered together.
        let (input, named_inputs) = match strategy {
            ExecutionStrategy::ChannelSplit(cs) => {
                let t = self.tensor_cache.get(&cs.input_name)?.clone();
                (t, Vec::new())
            }
            ExecutionStrategy::DimSplit(ds) => {
                let t = self.tensor_cache.get(&ds.input_name)?.clone();
                (t, Vec::new())
            }
            ExecutionStrategy::Tiled(tiling) => {
                let t = self.tensor_cache.get(&tiling.input_name)?.clone();
                (t, Vec::new())
            }
            ExecutionStrategy::Single { .. } => {
                let filtered = &meta.dependencies.filtered_inputs;
                let mut named = Vec::with_capacity(filtered.len());
                for name in filtered {
                    let arr = self.tensor_cache.get(name)?;
                    named.push((name.clone(), arr.clone()));
                }
                let concatenated = self.tensor_cache.gather(filtered)?;
                (concatenated, named)
            }
        };
        Ok(Some(SliceWork {
            slice_id: slice_id.clone(),
            input,
            named_inputs,
            backend: node.backend,
            use_circuit: node.use_circuit,
            tiling: meta.tiling.clone(),
            channel_split: meta.channel_split.clone(),
            circuit_path: node.circuit_path.clone(),
            onnx_path: node.onnx_path.clone(),
            slice_meta: meta.clone(),
        }))
    }

    /// Record the output of the current slice and advance to the next slice
    /// in the execution chain.
    ///
    /// The output is stored under the strategy's output name. For `Single`
    /// slices it is stored under every name in `dependencies.output`. The
    /// execution info is appended to the run results as the witness execution.
    ///
    /// Errors if:
    /// - `result` is for a slice other than the current one;
    /// - the run is already complete;
    /// - the slice is unknown;
    /// - a `Single` slice declares no output names.
    pub fn apply_result(&mut self, result: SliceExecutionResult) -> Result<()> {
        let slice_id = &result.slice_id;
        // Results must arrive in chain order, one per handed-out slice.
        match self.current_slice.as_deref() {
            Some(expected) if expected != slice_id => {
                return Err(DsperseError::Pipeline(format!(
                    "out-of-order result: expected {expected}, got {slice_id}"
                )));
            }
            None => {
                return Err(DsperseError::Pipeline(format!(
                    "pipeline already complete, unexpected result for {slice_id}"
                )));
            }
            _ => {}
        }
        let meta = self
            .run_meta
            .slices
            .get(slice_id)
            .ok_or_else(|| DsperseError::Pipeline(format!("unknown slice {slice_id}")))?;
        // NOTE(review): `use_circuit` is passed as `false` here, unlike in
        // `next_slice`; presumably it does not affect which output names the
        // strategy reports — confirm.
        let strategy = ExecutionStrategy::from_metadata(meta, false)?;
        match strategy {
            ExecutionStrategy::ChannelSplit(cs) => {
                self.tensor_cache.put(cs.output_name.clone(), result.output);
            }
            ExecutionStrategy::DimSplit(ds) => {
                self.tensor_cache.put(ds.output_name.clone(), result.output);
            }
            ExecutionStrategy::Tiled(tiling) => {
                self.tensor_cache
                    .put(tiling.output_name.clone(), result.output);
            }
            ExecutionStrategy::Single { .. } => {
                if meta.dependencies.output.is_empty() {
                    return Err(DsperseError::Pipeline(format!(
                        "slice {slice_id} has no output dependency names"
                    )));
                }
                // NOTE(review): the single result tensor is stored under every
                // declared output name, so outputs of a multi-output slice
                // would all alias the same tensor — confirm this is intended.
                for name in &meta.dependencies.output {
                    self.tensor_cache.put(name.clone(), result.output.clone());
                }
            }
        }
        self.results.push(ExecutionResultEntry {
            slice_id: slice_id.clone(),
            witness_execution: Some(result.execution_info),
            proof_execution: None,
            verification_execution: None,
        });
        // Advance along the chain; `None` marks the run as complete.
        let next = self
            .execution_chain
            .nodes
            .get(slice_id)
            .and_then(|n| n.next.clone());
        self.current_slice = next;
        Ok(())
    }

    /// Whether there is no slice left to execute.
    pub fn is_complete(&self) -> bool {
        self.current_slice.is_none()
    }

    /// Output tensor of the model's last slice (the last entry of the model
    /// metadata's slice list), if it has been produced.
    ///
    /// Uses the strategy's output name when it has one, otherwise the slice's
    /// first declared output dependency. Returns `None` when any lookup fails.
    pub fn final_output(&self) -> Option<&ArrayD> {
        let last_slice = self.model_meta.slices.last()?;
        let slice_id = format!("slice_{}", last_slice.index);
        let meta = self.run_meta.slices.get(&slice_id)?;
        let strategy = ExecutionStrategy::from_metadata(meta, false).ok()?;
        match strategy.output_name() {
            Some(name) => self.tensor_cache.try_get(name),
            None => {
                let output_name = meta.dependencies.output.first()?;
                self.tensor_cache.try_get(output_name)
            }
        }
    }

    /// Consume the run and return its run metadata.
    ///
    /// The recorded execution results are attached, and `source_path` is set
    /// to the slices directory.
    pub fn into_run_metadata(self) -> RunMetadata {
        let mut meta = self.run_meta;
        meta.execution_chain.execution_results = self.results;
        meta.source_path = Some(self.slices_dir.to_string_lossy().into_owned());
        meta
    }

    /// Directory holding the sliced model.
    pub fn slices_dir(&self) -> &Path {
        &self.slices_dir
    }

    /// Model metadata loaded from `slices_dir`.
    pub fn model_meta(&self) -> &ModelMetadata {
        &self.model_meta
    }

    /// Run metadata built for this run.
    pub fn run_meta(&self) -> &RunMetadata {
        &self.run_meta
    }

    /// Tensors produced so far, keyed by name.
    pub fn tensor_cache(&self) -> &TensorStore {
        &self.tensor_cache
    }
}
================================================ FILE: crates/dsperse/src/pipeline/mod.rs ================================================
mod channel_split;
mod combined;
mod compiler;
mod dim_split;
mod incremental;
pub mod packager;
mod prover;
pub mod publisher;
pub mod runner;
pub mod slice_cache;
mod stage;
pub mod strategy;
pub mod tensor_store;
pub mod tile_executor;
mod tiled;
mod verifier;

pub use combined::CombinedRun;
pub use compiler::{
    CompileReport,
HolographicSetupReport, SliceAnalysisReport, analyze_slices, compile_slices,
    setup_holographic_for_slices,
};
pub use incremental::{IncrementalRun, SliceExecutionResult, SliceWork};
pub use prover::prove_run;
pub use runner::{RunConfig, extract_onnx_initializers, run_inference};
pub use slice_cache::SliceAssets;
pub use strategy::ExecutionStrategy;
pub use tensor_store::TensorStore;
pub use tiled::{reconstruct_from_tiles, split_for_tiling, split_into_tiles};
pub use verifier::verify_run;
================================================ FILE: crates/dsperse/src/pipeline/packager.rs ================================================
use std::collections::HashSet;
use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};

use serde::Serialize;
use sha2::{Digest, Sha256};
use walkdir::WalkDir;

use crate::error::{DsperseError, Result};
use crate::pipeline::compiler::compute_bundle_signature;
use crate::pipeline::runner::load_model_metadata;
use crate::schema::metadata::SliceMetadata;
use crate::utils::paths::resolve_relative_path;

// NOTE(review): generic arguments in this extracted copy appear to have been
// stripped (e.g. `Option,`, `Vec>`, `Result>`); restore them from upstream.

/// Options for `package_content_addressed`.
pub struct PackageConfig {
    // Receives `components/`, `wb/` and `manifest.msgpack`.
    pub output_dir: PathBuf,
    pub author: Option,
    pub model_version: Option,
    // When unset, falls back to the name of the slices directory's parent,
    // then to "unknown".
    pub model_name: Option,
    pub timeout: Option,
    // Validated and normalized by `normalize_curve`; also passed to the
    // component hashing functions.
    pub curve: Option,
}

/// Summary of a packaging run.
#[derive(Debug)]
pub struct PackageResult {
    // Number of distinct component directories written under `components/`.
    pub component_count: usize,
    // Number of distinct blobs written under `wb/`.
    pub wb_count: usize,
    pub manifest_path: PathBuf,
    // Bytes written: copied component files, new blobs and the manifest.
    pub total_size: u64,
}

/// Whole-model artifact (`metadata.msgpack`, `model.onnx`) stored in `wb/`.
#[derive(Serialize)]
struct ArtifactRef {
    sha256: String,
    role: String,
    filename: String,
    size_bytes: u64,
}

/// Top-level manifest; serialized with named fields to `manifest.msgpack`.
#[derive(Serialize)]
struct Manifest {
    version: u32,
    model: ModelInfo,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    artifacts: Vec,
    components: Vec,
    dag: Vec,
}

/// Model-level section of the manifest.
#[derive(Serialize)]
struct ModelInfo {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    curve: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    author: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    version: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    timeout: Option,
    input_schema: InputSchema,
    #[serde(skip_serializing_if = "Option::is_none")]
    dsperse_version: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    jstprove_version: Option,
}

/// Model input/output shapes and output names, copied from the model metadata.
#[derive(Serialize)]
struct InputSchema {
    shape: Vec>,
    output_shapes: Vec>,
    output_names: Vec,
}

/// One manifest entry per slice; `sha256` names its directory under `components/`.
#[derive(Serialize)]
struct ComponentEntry {
    index: usize,
    name: String,
    sha256: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    curve: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    proof_system: Option,
    files: Vec,
    weights: Vec,
}

/// Per-slice payload blob, stored at `wb/<sha256>`.
#[derive(Serialize)]
struct WeightRef {
    sha256: String,
    role: String,
    filename: String,
    size_bytes: u64,
}

/// Data-flow record for one slice: tensor names consumed/produced and their shapes.
#[derive(Serialize)]
struct DagNode {
    component_index: usize,
    inputs: Vec,
    outputs: Vec,
    input_shape: Vec>,
    output_shape: Vec>,
}

/// Curve identifiers accepted by `normalize_curve`.
const VALID_CURVES: &[&str] = &[
    "bn254",
    "goldilocks",
    "goldilocks_basefold",
    "goldilocks_ext2",
    "goldilocks_whir",
    "goldilocks_whir_pq",
];

/// Trim and lowercase an optional curve name and check it against `VALID_CURVES`.
///
/// `None` passes through as `Ok(None)`. An empty (after trimming) or unknown
/// name is a `DsperseError::Other`.
fn normalize_curve(curve: Option<&str>) -> Result> {
    let Some(c) = curve else { return Ok(None) };
    let c = c.trim().to_ascii_lowercase();
    if c.is_empty() {
        return Err(DsperseError::Other("curve must not be empty".into()));
    }
    if !VALID_CURVES.contains(&c.as_str()) {
        return Err(DsperseError::Other(format!(
            "unsupported curve {c:?}; expected one of: {}",
            VALID_CURVES.join(", ")
        )));
    }
    Ok(Some(c))
}

/// Package a sliced model into a content-addressed layout under `config.output_dir`.
///
/// Layout:
/// - `components/<hash>/`: each distinct component (circuit bundle files, or
///   a single ONNX file);
/// - `wb/<sha256>`: per-slice payload blobs plus the required whole-model
///   artifacts, deduplicated by content hash;
/// - `manifest.msgpack`: the manifest describing the components, blobs and
///   slice DAG.
///
/// Errors if:
/// - `slices_dir` is not a directory;
/// - the curve is invalid;
/// - metadata cannot be loaded;
/// - a required artifact is missing;
/// - a symlink is encountered;
/// - any I/O or serialization step fails.
pub fn package_content_addressed(
    slices_dir: &Path,
    config: &PackageConfig,
) -> Result {
    if !slices_dir.is_dir() {
        return Err(DsperseError::Other(format!(
            "slices directory not found: {}",
            slices_dir.display()
        )));
    }
    let curve = normalize_curve(config.curve.as_deref())?;
    let model_meta = load_model_metadata(slices_dir)?;
    let components_dir = config.output_dir.join("components");
    let wb_dir = config.output_dir.join("wb");
    fs::create_dir_all(&components_dir).map_err(|e| DsperseError::io(e, &components_dir))?;
    fs::create_dir_all(&wb_dir).map_err(|e| DsperseError::io(e, &wb_dir))?;
    let mut components: Vec = Vec::new();
    let mut dag_nodes: Vec = Vec::new();
    // Component hashes already materialized, so shared components are copied once.
    let mut written_components: HashSet = HashSet::new();
let mut written_wbs: HashSet = HashSet::new();
    let mut total_size: u64 = 0;
    for slice in &model_meta.slices {
        let slice_dir = slices_dir.join(format!("slice_{}", slice.index));
        let (component_hash, component_files, proof_system, source) =
            extract_component(slices_dir, slice, &slice_dir, curve.as_deref())?;
        // Identical components (same content hash) are materialized only once.
        if !written_components.contains(&component_hash) {
            let dest = components_dir.join(&component_hash);
            fs::create_dir_all(&dest).map_err(|e| DsperseError::io(e, &dest))?;
            match &source {
                ComponentSource::CircuitBundle(circuit_dir) => {
                    total_size += copy_files_flat(circuit_dir, &dest)?;
                }
                ComponentSource::OnnxFile(onnx_path) => {
                    if let Some(filename) = component_files.first() {
                        let dest_file = dest.join(filename);
                        // `fs::copy` returns the number of bytes it copied, so
                        // count that directly. Re-stat'ing the source costs an
                        // extra syscall and could disagree with what was
                        // written if the source changed meanwhile.
                        total_size += fs::copy(onnx_path, &dest_file)
                            .map_err(|e| DsperseError::io(e, onnx_path))?;
                    }
                }
            }
            written_components.insert(component_hash.clone());
        }
        let mut weights: Vec = Vec::new();
        let payload_blobs = collect_payload_blobs(slices_dir, slice, &slice_dir)?;
        for (role, filename, data) in &payload_blobs {
            let hash = sha256_bytes(data);
            // Blobs are stored once per content hash under wb/<sha256>.
            if !written_wbs.contains(&hash) {
                let wb_path = wb_dir.join(&hash);
                fs::write(&wb_path, data).map_err(|e| DsperseError::io(e, &wb_path))?;
                total_size += data.len() as u64;
                written_wbs.insert(hash.clone());
            }
            weights.push(WeightRef {
                sha256: hash,
                role: role.clone(),
                filename: filename.clone(),
                size_bytes: data.len() as u64,
            });
        }
        components.push(ComponentEntry {
            index: slice.index,
            name: format!("slice_{}", slice.index),
            sha256: component_hash,
            curve: curve.clone(),
            proof_system,
            files: component_files,
            weights,
        });
        dag_nodes.push(DagNode {
            component_index: slice.index,
            inputs: slice.dependencies.input.clone(),
            outputs: slice.dependencies.output.clone(),
            input_shape: slice.shape.tensor_shape.input.clone(),
            output_shape: slice.shape.tensor_shape.output.clone(),
        });
        if (slice.index + 1) % 50 == 0 {
            tracing::info!(
                progress = slice.index + 1,
                total = model_meta.slices.len(),
                "packaging slices"
            );
        }
    }
    // Whole-model artifacts are mandatory. They share wb/ and its dedup set
    // with the per-slice blobs.
    let mut artifacts: Vec = Vec::new();
    let model_artifact_files = ["metadata.msgpack", "model.onnx"];
    for filename in &model_artifact_files {
        let src = slices_dir.join(filename);
        if !src.is_file() {
            return Err(DsperseError::Other(format!(
                "required model artifact '{}' not found at {}",
                filename,
                src.display()
            )));
        }
        reject_symlink_path(&src)?;
        let data = fs::read(&src).map_err(|e| DsperseError::io(e, &src))?;
        let hash = sha256_bytes(&data);
        if !written_wbs.contains(&hash) {
            let wb_path = wb_dir.join(&hash);
            fs::write(&wb_path, &data).map_err(|e| DsperseError::io(e, &wb_path))?;
            total_size += data.len() as u64;
            written_wbs.insert(hash.clone());
        }
        artifacts.push(ArtifactRef {
            sha256: hash,
            role: "artifact".to_string(),
            filename: (*filename).to_string(),
            size_bytes: data.len() as u64,
        });
        tracing::info!(filename, "packaged model artifact");
    }
    let model_name = config
        .model_name
        .clone()
        .or_else(|| {
            slices_dir
                .parent()
                .and_then(|p| p.file_name())
                .and_then(|n| n.to_str())
                .map(String::from)
        })
        .unwrap_or_else(|| "unknown".to_string());
    let manifest = Manifest {
        version: 1,
        model: ModelInfo {
            name: model_name,
            curve: curve.clone(),
            author: config.author.clone(),
            version: config.model_version.clone(),
            timeout: config.timeout,
            input_schema: InputSchema {
                shape: model_meta.input_shape,
                output_shapes: model_meta.output_shapes,
                output_names: model_meta.output_names,
            },
            dsperse_version: model_meta.dsperse_version,
            jstprove_version: model_meta.jstprove_version,
        },
        artifacts,
        components,
        dag: dag_nodes,
    };
    let manifest_path = config.output_dir.join("manifest.msgpack");
    let manifest_bytes = rmp_serde::to_vec_named(&manifest)
        .map_err(|e| DsperseError::Other(format!("failed to serialize manifest: {e}")))?;
    fs::write(&manifest_path, &manifest_bytes).map_err(|e| DsperseError::io(e, &manifest_path))?;
    total_size += manifest_bytes.len() as u64;
    Ok(PackageResult {
        component_count: written_components.len(),
        wb_count: written_wbs.len(),
        manifest_path,
        total_size,
    })
}

/// Locate the compiled circuit bundle for `slice`, if any.
///
/// Candidates are tried in order:
/// 1. the default `slice_<index>/jstprove/circuit.bundle`;
/// 2. the first channel-split group's `jstprove_circuit_path`;
/// 3. the dim-split `jstprove_circuit_path`.
///
/// Returns `Ok(None)` when none of these is an existing directory.
fn resolve_circuit_dir(slices_dir: &Path, slice: &SliceMetadata) -> Result> {
    let bundle = slices_dir
        .join(format!("slice_{}", slice.index))
        .join("jstprove")
        .join("circuit.bundle");
    if bundle.is_dir() {
        return Ok(Some(bundle));
    }
    if let Some(ref cs) = slice.channel_split
        && let Some(group) = cs.groups.first()
        && let Some(ref circuit_path) = group.jstprove_circuit_path
    {
        let abs = resolve_relative_path(slices_dir, circuit_path)?;
        if abs.is_dir() {
            return Ok(Some(abs));
        }
    }
    if let Some(ref ds) = slice.dim_split
        && let Some(ref circuit_path) = ds.jstprove_circuit_path
    {
        let abs = resolve_relative_path(slices_dir, circuit_path)?;
        if abs.is_dir() {
            return Ok(Some(abs));
        }
    }
    Ok(None)
}

/// Where a component's files are copied from.
enum ComponentSource {
    // A circuit bundle directory, copied recursively.
    CircuitBundle(PathBuf),
    // A single ONNX file.
    OnnxFile(PathBuf),
}

/// Pick the ONNX file fed to the bundle signature for `slice`.
///
/// Candidates, in order: the first channel-split group's ONNX, the dim-split
/// template, the tile ONNX, then the slice's own ONNX. Errors if a configured
/// candidate is missing or any chosen path is a symlink.
fn resolve_source_onnx(slices_dir: &Path, slice: &SliceMetadata) -> Result {
    if let Some(ref cs) = slice.channel_split
        && let Some(group) = cs.groups.first()
    {
        let p = resolve_relative_path(slices_dir, &group.path)?;
        reject_symlink_path(&p)?;
        if !p.is_file() {
            return Err(DsperseError::Other(format!(
                "slice {} channel group ONNX configured but missing: {}",
                slice.index,
                p.display()
            )));
        }
        return Ok(p);
    }
    if let Some(ref ds) = slice.dim_split
        && let Some(ref tmpl) = ds.template_path
    {
        let p = resolve_relative_path(slices_dir, tmpl)?;
        reject_symlink_path(&p)?;
        if !p.is_file() {
            return Err(DsperseError::Other(format!(
                "slice {} dim-split template configured but missing: {}",
                slice.index,
                p.display()
            )));
        }
        return Ok(p);
    }
    if let Some(ref tiling) = slice.tiling
        && let Some(ref tile) = tiling.tile
    {
        let p = resolve_relative_path(slices_dir, &tile.path)?;
        reject_symlink_path(&p)?;
        if !p.is_file() {
            return Err(DsperseError::Other(format!(
                "slice {} tile ONNX configured but missing: {}",
                slice.index,
                p.display()
            )));
        }
        return Ok(p);
    }
    let p = slice.resolve_onnx(slices_dir)?;
    reject_symlink_path(&p)?;
    Ok(p)
}

/// List the regular files under `dir` as sorted, `/`-separated relative paths.
/// Symlinks and non-normal path components are rejected.
fn list_bundle_files(dir: &Path) -> Result> {
    let mut files =
Vec::new();
    for entry in WalkDir::new(dir) {
        let entry = entry.map_err(|e| DsperseError::Other(e.to_string()))?;
        reject_symlink(&entry)?;
        if entry.file_type().is_file() {
            // Build a portable `/`-joined relative path; anything other than
            // plain path components is refused.
            let relative = entry
                .path()
                .strip_prefix(dir)
                .map_err(|e| DsperseError::Other(e.to_string()))?
                .components()
                .map(|c| match c {
                    std::path::Component::Normal(part) => Ok(part.to_string_lossy().into_owned()),
                    _ => Err(DsperseError::Other(
                        "unexpected non-normal path component in bundle".into(),
                    )),
                })
                .collect::>>()?
                .join("/");
            files.push(relative);
        }
    }
    // Sorted so the listing (recorded as the component's `files`) is
    // deterministic regardless of directory iteration order.
    files.sort();
    Ok(files)
}

// NOTE(review): generic arguments in this extracted copy appear to have been
// stripped (e.g. `Result)>>`, `Vec<(String, String, Vec)>`); restore from upstream.

/// Determine the content-addressed identity of a slice's component.
///
/// Returns `(hash, files, proof_system, source)`:
/// - With a circuit bundle: the hash comes from `compute_bundle_signature`
///   (source ONNX path, curve, bundle directory), `files` lists the bundle's
///   files, and `proof_system` is `"jstprove"`.
/// - Otherwise: the slice's ONNX file is used, hashed over the curve,
///   filename and file contents via `hash_named_file`, and `proof_system`
///   is `None`.
///
/// Errors if neither a bundle nor an ONNX file is available, or if a
/// symlink is encountered.
fn extract_component(
    slices_dir: &Path,
    slice: &SliceMetadata,
    // Currently unused.
    _slice_dir: &Path,
    curve: Option<&str>,
) -> Result<(String, Vec, Option, ComponentSource)> {
    if let Some(dir) = resolve_circuit_dir(slices_dir, slice)? {
        let onnx_path = resolve_source_onnx(slices_dir, slice)?;
        let sig = compute_bundle_signature(&onnx_path, curve, &dir)?;
        let files = list_bundle_files(&dir)?;
        return Ok((
            sig,
            files,
            Some("jstprove".to_string()),
            ComponentSource::CircuitBundle(dir),
        ));
    }
    let onnx_path = slice.resolve_onnx(slices_dir)?;
    reject_symlink_path(&onnx_path)?;
    if onnx_path.is_file() {
        let filename = onnx_path
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("model.onnx")
            .to_string();
        let hash = hash_named_file(&onnx_path, &filename, curve)?;
        return Ok((
            hash,
            vec![filename],
            None,
            ComponentSource::OnnxFile(onnx_path),
        ));
    }
    Err(DsperseError::Other(format!(
        "slice {} has no circuit directory or ONNX artifact to package",
        slice.index
    )))
}

/// Gather a slice's blobs for `wb/` as `(role, filename, bytes)`.
///
/// Collected blobs:
/// - the slice's ONNX payload (`"payload"`);
/// - every channel-split group ONNX (`"channel_group"`);
/// - the channel-split bias (`"bias"`).
///
/// Files that do not exist are skipped; symlinks are rejected.
///
/// NOTE(review): dim-split templates and tile ONNX files are not collected
/// here — confirm that is intended.
fn collect_payload_blobs(
    slices_dir: &Path,
    slice: &SliceMetadata,
    slice_dir: &Path,
) -> Result)>> {
    let mut blobs: Vec<(String, String, Vec)> = Vec::new();
    // If the metadata cannot resolve the ONNX path, fall back to the
    // conventional `payload/slice_<index>.onnx` under the slice directory.
    let onnx_path = slice.resolve_onnx(slices_dir).unwrap_or_else(|_| {
        slice_dir
            .join("payload")
            .join(format!("slice_{}.onnx", slice.index))
    });
    reject_symlink_path(&onnx_path)?;
    if onnx_path.is_file() {
        let data = fs::read(&onnx_path).map_err(|e| DsperseError::io(e, &onnx_path))?;
        let filename = onnx_path
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("model.onnx")
            .to_string();
        blobs.push(("payload".to_string(), filename, data));
    }
    if let Some(ref cs) = slice.channel_split {
        for group in &cs.groups {
            let group_path = resolve_relative_path(slices_dir, &group.path)?;
            reject_symlink_path(&group_path)?;
            if group_path.is_file() {
                let data = fs::read(&group_path).map_err(|e| DsperseError::io(e, &group_path))?;
                let filename = group_path
                    .file_name()
                    .and_then(|n| n.to_str())
                    .unwrap_or("group.onnx")
                    .to_string();
                blobs.push(("channel_group".to_string(), filename, data));
            }
        }
        if let Some(ref bias_path) = cs.bias_path {
            let abs = resolve_relative_path(slices_dir, bias_path)?;
            reject_symlink_path(&abs)?;
            if abs.is_file() {
                let data = fs::read(&abs).map_err(|e| DsperseError::io(e, &abs))?;
                blobs.push(("bias".to_string(), "bias.msgpack".to_string(), data));
            }
        }
    }
    Ok(blobs)
}

/// Refuse `path` if it is itself a symbolic link.
///
/// Only the final component is inspected (`symlink_metadata` still follows
/// links in parent directories). A path that cannot be stat'ed, e.g. a
/// missing one, is accepted.
fn reject_symlink_path(path: &Path) -> Result<()> {
    if path
        .symlink_metadata()
        .is_ok_and(|m| m.file_type().is_symlink())
    {
        return Err(DsperseError::Other(format!(
            "symlinked file is not allowed: {}",
            path.display()
        )));
    }
    Ok(())
}

/// Refuse a `walkdir` entry whose file type is a symbolic link.
fn reject_symlink(entry: &walkdir::DirEntry) -> Result<()> {
    if entry.file_type().is_symlink() {
        return Err(DsperseError::Other(format!(
            "symlinked bundle entry is not allowed: {}",
            entry.path().display()
        )));
    }
    Ok(())
}

/// Hex-encoded content hash for an uncompiled ONNX component.
///
/// SHA-256 over, in order:
/// 1. the curve (if any), prefixed with its byte length as a little-endian `u64`;
/// 2. the filename, prefixed the same way;
/// 3. the file length as a little-endian `u64`;
/// 4. the file contents.
fn hash_named_file(path: &Path, filename: &str, curve: Option<&str>) -> Result {
    let mut hasher = Sha256::new();
    // Each variable-length field is preceded by its length.
    if let Some(c) = curve {
        let c_bytes = c.as_bytes();
        hasher.update((c_bytes.len() as u64).to_le_bytes());
        hasher.update(c_bytes);
    }
    let name_bytes = filename.as_bytes();
    hasher.update((name_bytes.len() as u64).to_le_bytes());
    hasher.update(name_bytes);
    let mut file = fs::File::open(path).map_err(|e| DsperseError::io(e, path))?;
    let file_len = file
        .metadata()
        .map_err(|e| DsperseError::io(e, path))?
.len();
    hasher.update(file_len.to_le_bytes());
    let mut buf = [0u8; 8192];
    // Count what was actually read. The length prefix above came from
    // metadata, and if the file changed underneath us the digest would bind
    // a length that does not match the hashed bytes.
    let mut bytes_read: u64 = 0;
    loop {
        let n = file.read(&mut buf).map_err(|e| DsperseError::io(e, path))?;
        if n == 0 {
            break;
        }
        hasher.update(&buf[..n]);
        bytes_read += n as u64;
    }
    if bytes_read != file_len {
        return Err(DsperseError::Other(format!(
            "file changed while hashing: {} (expected {file_len} bytes, read {bytes_read})",
            path.display()
        )));
    }
    Ok(encode_hex(&hasher.finalize()))
}

/// Hex-encoded SHA-256 digest of `data`.
fn sha256_bytes(data: &[u8]) -> String {
    let mut hasher = Sha256::new();
    hasher.update(data);
    encode_hex(&hasher.finalize())
}

/// Lowercase hex encoding of `bytes`, two characters per byte.
fn encode_hex(bytes: &[u8]) -> String {
    use std::fmt::Write;
    let mut s = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        write!(s, "{:02x}", b).expect("writing to a String is infallible");
    }
    s
}

/// Recursively copy every regular file under `source_dir` into `dest_dir`.
///
/// Relative paths are preserved (despite the name, subdirectories are kept).
/// Returns the total number of bytes copied. Symlinks are rejected.
fn copy_files_flat(source_dir: &Path, dest_dir: &Path) -> Result {
    let mut total: u64 = 0;
    for entry in WalkDir::new(source_dir) {
        let entry = entry.map_err(|e| DsperseError::Other(e.to_string()))?;
        reject_symlink(&entry)?;
        if entry.file_type().is_file() {
            let relative = entry
                .path()
                .strip_prefix(source_dir)
                .map_err(|e| DsperseError::Other(e.to_string()))?;
            let dest_path = dest_dir.join(relative);
            if let Some(parent) = dest_path.parent() {
                fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?;
            }
            // `fs::copy` reports the bytes it copied; use that instead of a
            // second `metadata()` call on the source.
            total += fs::copy(entry.path(), &dest_path)
                .map_err(|e| DsperseError::io(e, entry.path()))?;
        }
    }
    Ok(total)
}
#[cfg(test)] mod tests { use super::*; use std::fs; use tempfile::TempDir; use crate::schema::metadata::{ Compilation, Dependencies, ModelMetadata, SliceShapeWrapper, TensorShape, }; use crate::slicer::onnx_proto; fn write_minimal_onnx(path: &Path, input_dim: i64) { let node = onnx_proto::NodeProto { op_type: "Relu".to_string(), input: vec!["x".to_string()], output: vec!["y".to_string()], ..Default::default() }; let graph = onnx_proto::make_graph( "g", vec![node], vec![onnx_proto::make_tensor_value_info("x", 1, &[1, input_dim])], vec![onnx_proto::make_tensor_value_info("y", 1, &[1, input_dim])], vec![], ); let model = onnx_proto::make_model(graph, 13); onnx_proto::save_model(&model, path).unwrap(); } fn create_test_model_metadata(slices_dir: &Path, count: usize) { let mut slices = Vec::new(); for i in 0..count { let slice_dir = slices_dir.join(format!("slice_{}", i)); let payload_dir = slice_dir.join("payload"); fs::create_dir_all(&payload_dir).unwrap(); write_minimal_onnx( &payload_dir.join(format!("slice_{}.onnx", i)), (64 + i) as i64, ); let circuit_dir = slice_dir.join("jstprove").join("circuit.bundle"); fs::create_dir_all(&circuit_dir).unwrap(); fs::write(circuit_dir.join("circuit.bin"), format!("circuit_{}", i)).unwrap(); fs::write( circuit_dir.join("settings.json"), format!("{{\"idx\":{}}}", i), ) .unwrap(); let inputs = if i == 0 { vec!["model_input".to_string()] } else { vec![format!("tensor_{}", i - 1)] }; let outputs = vec![format!("tensor_{}", i)]; slices.push(SliceMetadata { index: i, filename: format!("slice_{}.onnx", i), path: slice_dir.to_string_lossy().to_string(), relative_path: format!("slice_{}/payload/slice_{}.onnx", i, i), shape: SliceShapeWrapper { tensor_shape: TensorShape { input: vec![vec![1, 3, 224, 224]], output: vec![vec![1, 64, 112, 112]], }, }, dependencies: Dependencies { input: inputs, output: outputs, filtered_inputs: vec![], }, tiling: None, channel_split: None, dim_split: None, compilation:
Compilation::default(), slice_metadata: None, slice_metadata_relative_path: None, }); } let meta = ModelMetadata { original_model: "test_model".to_string(), model_type: "onnx".to_string(), input_shape: vec![vec![1, 3, 224, 224]], output_shapes: vec![vec![1, 1000]], output_names: vec!["output".to_string()], slice_points: (0..count).collect(), slices, dsperse_version: Some("0.0.1-test".to_string()), dsperse_rev: None, jstprove_version: Some("0.1.0-test".to_string()), jstprove_rev: None, traced_shapes: None, traced_types: None, original_model_path: None, folded_constant_names: vec![], }; meta.save(&slices_dir.join("metadata.msgpack")).unwrap(); ensure_test_artifacts(slices_dir); } fn ensure_test_artifacts(slices_dir: &Path) { let p = slices_dir.join("model.onnx"); if !p.exists() { fs::write(&p, b"fake-onnx-for-test").unwrap(); } } #[test] fn test_content_addressed_output_structure() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); create_test_model_metadata(&slices_dir, 3); let output_dir = tmp.path().join("output"); let config = PackageConfig { output_dir: output_dir.clone(), author: Some("test-author".to_string()), model_version: Some("1.0.0".to_string()), model_name: Some("test-model".to_string()), timeout: Some(300), curve: None, }; let result = package_content_addressed(&slices_dir, &config).unwrap(); assert_eq!(result.component_count, 3); assert_eq!(result.wb_count, 5); assert!(result.total_size > 0); assert!(output_dir.join("components").is_dir()); assert!(output_dir.join("wb").is_dir()); assert!(output_dir.join("manifest.msgpack").is_file()); let manifest_bytes = fs::read(output_dir.join("manifest.msgpack")).unwrap(); let manifest: serde_json::Value = rmp_serde::from_slice(&manifest_bytes).unwrap(); let arts = manifest["artifacts"].as_array().unwrap(); assert_eq!(arts.len(), 2); let filenames: Vec<&str> = arts.iter().filter_map(|a| a["filename"].as_str()).collect(); 
assert!(filenames.contains(&"metadata.msgpack")); assert!(filenames.contains(&"model.onnx")); for art in arts { assert_eq!(art["role"].as_str().unwrap(), "artifact"); assert!(art["sha256"].as_str().unwrap().len() == 64); assert!(art["size_bytes"].as_u64().unwrap() > 0); } } #[test] fn test_missing_model_onnx_fails() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); create_test_model_metadata(&slices_dir, 1); fs::remove_file(slices_dir.join("model.onnx")).unwrap(); let output_dir = tmp.path().join("output"); let config = PackageConfig { output_dir, author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let err = package_content_addressed(&slices_dir, &config).unwrap_err(); assert!(err.to_string().contains("model.onnx")); } #[cfg(unix)] #[test] fn test_symlinked_artifact_rejected() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); create_test_model_metadata(&slices_dir, 1); fs::remove_file(slices_dir.join("model.onnx")).unwrap(); let target = tmp.path().join("evil.bin"); fs::write(&target, b"evil").unwrap(); std::os::unix::fs::symlink(&target, slices_dir.join("model.onnx")).unwrap(); let output_dir = tmp.path().join("output"); let config = PackageConfig { output_dir, author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let err = package_content_addressed(&slices_dir, &config).unwrap_err(); assert!(err.to_string().contains("symlink")); } #[test] fn test_manifest_structure() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); create_test_model_metadata(&slices_dir, 2); let output_dir = tmp.path().join("output"); let config = PackageConfig { output_dir: output_dir.clone(), author: Some("test-author".to_string()), model_version: Some("1.0.0".to_string()), 
model_name: Some("test-model".to_string()), timeout: Some(300), curve: None, }; package_content_addressed(&slices_dir, &config).unwrap(); let manifest: serde_json::Value = rmp_serde::from_slice(&fs::read(output_dir.join("manifest.msgpack")).unwrap()).unwrap(); assert_eq!(manifest["version"], 1); assert_eq!(manifest["model"]["name"], "test-model"); assert_eq!(manifest["model"]["author"], "test-author"); assert_eq!(manifest["model"]["version"], "1.0.0"); assert_eq!(manifest["model"]["timeout"], 300); let components = manifest["components"].as_array().unwrap(); assert_eq!(components.len(), 2); for comp in components { let sha = comp["sha256"].as_str().unwrap(); assert_eq!(sha.len(), 64); assert!(!comp["files"].as_array().unwrap().is_empty()); assert_eq!(comp["proof_system"], "jstprove"); assert!(!comp["weights"].as_array().unwrap().is_empty()); } let dag = manifest["dag"].as_array().unwrap(); assert_eq!(dag.len(), 2); assert_eq!(dag[0]["inputs"][0], "model_input"); assert_eq!(dag[0]["outputs"][0], "tensor_0"); assert_eq!(dag[1]["inputs"][0], "tensor_0"); } #[test] fn test_component_files_exist() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); create_test_model_metadata(&slices_dir, 1); let output_dir = tmp.path().join("output"); let config = PackageConfig { output_dir: output_dir.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; package_content_addressed(&slices_dir, &config).unwrap(); let manifest: serde_json::Value = rmp_serde::from_slice(&fs::read(output_dir.join("manifest.msgpack")).unwrap()).unwrap(); let comp = &manifest["components"][0]; let sha = comp["sha256"].as_str().unwrap(); let comp_dir = output_dir.join("components").join(sha); assert!(comp_dir.is_dir()); assert!(comp_dir.join("circuit.bin").is_file()); assert!(comp_dir.join("settings.json").is_file()); } #[test] fn test_wb_files_exist() { let tmp = TempDir::new().unwrap(); 
let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); create_test_model_metadata(&slices_dir, 1); let output_dir = tmp.path().join("output"); let config = PackageConfig { output_dir: output_dir.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; package_content_addressed(&slices_dir, &config).unwrap(); let manifest: serde_json::Value = rmp_serde::from_slice(&fs::read(output_dir.join("manifest.msgpack")).unwrap()).unwrap(); let weight = &manifest["components"][0]["weights"][0]; let sha = weight["sha256"].as_str().unwrap(); let wb_path = output_dir.join("wb").join(sha); assert!(wb_path.is_file()); let size = weight["size_bytes"].as_u64().unwrap(); assert_eq!(fs::metadata(&wb_path).unwrap().len(), size); } #[test] fn test_hash_determinism() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); create_test_model_metadata(&slices_dir, 2); let out1 = tmp.path().join("out1"); let out2 = tmp.path().join("out2"); let config1 = PackageConfig { output_dir: out1.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let config2 = PackageConfig { output_dir: out2.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; package_content_addressed(&slices_dir, &config1).unwrap(); package_content_addressed(&slices_dir, &config2).unwrap(); let m1: serde_json::Value = rmp_serde::from_slice(&fs::read(out1.join("manifest.msgpack")).unwrap()).unwrap(); let m2: serde_json::Value = rmp_serde::from_slice(&fs::read(out2.join("manifest.msgpack")).unwrap()).unwrap(); for i in 0..2 { assert_eq!(m1["components"][i]["sha256"], m2["components"][i]["sha256"]); } } #[test] fn test_curve_changes_hash() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); 
create_test_model_metadata(&slices_dir, 2); let out_none = tmp.path().join("out_none"); let out_bn = tmp.path().join("out_bn"); let out_gl = tmp.path().join("out_gl"); let config_none = PackageConfig { output_dir: out_none.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let config_bn = PackageConfig { output_dir: out_bn.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: Some("bn254".to_string()), }; let config_gl = PackageConfig { output_dir: out_gl.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: Some("goldilocks".to_string()), }; package_content_addressed(&slices_dir, &config_none).unwrap(); package_content_addressed(&slices_dir, &config_bn).unwrap(); package_content_addressed(&slices_dir, &config_gl).unwrap(); let m_none: serde_json::Value = rmp_serde::from_slice(&fs::read(out_none.join("manifest.msgpack")).unwrap()).unwrap(); let m_bn: serde_json::Value = rmp_serde::from_slice(&fs::read(out_bn.join("manifest.msgpack")).unwrap()).unwrap(); let m_gl: serde_json::Value = rmp_serde::from_slice(&fs::read(out_gl.join("manifest.msgpack")).unwrap()).unwrap(); for i in 0..2 { let h_none = m_none["components"][i]["sha256"].as_str().unwrap(); let h_bn = m_bn["components"][i]["sha256"].as_str().unwrap(); let h_gl = m_gl["components"][i]["sha256"].as_str().unwrap(); assert_ne!(h_none, h_bn, "curve=None vs bn254 should differ"); assert_ne!(h_none, h_gl, "curve=None vs goldilocks should differ"); assert_ne!(h_bn, h_gl, "bn254 vs goldilocks should differ"); } } #[test] fn test_curve_changes_hash_uncompiled_onnx() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); let slice_dir = slices_dir.join("slice_0"); let payload_dir = slice_dir.join("payload"); fs::create_dir_all(&payload_dir).unwrap(); fs::write(payload_dir.join("slice_0.onnx"), "onnx_payload").unwrap(); let meta = 
ModelMetadata { original_model: "test".to_string(), model_type: "onnx".to_string(), input_shape: vec![vec![1, 3]], output_shapes: vec![vec![1, 3]], output_names: vec!["out".to_string()], slice_points: vec![0], slices: vec![SliceMetadata { index: 0, filename: "slice_0.onnx".to_string(), path: slice_dir.to_string_lossy().to_string(), relative_path: "slice_0/payload/slice_0.onnx".to_string(), shape: SliceShapeWrapper { tensor_shape: TensorShape { input: vec![vec![1, 3]], output: vec![vec![1, 3]], }, }, dependencies: Dependencies { input: vec!["in".to_string()], output: vec!["out".to_string()], filtered_inputs: vec![], }, tiling: None, channel_split: None, dim_split: None, compilation: Compilation::default(), slice_metadata: None, slice_metadata_relative_path: None, }], dsperse_version: None, dsperse_rev: None, jstprove_version: None, jstprove_rev: None, traced_shapes: None, traced_types: None, original_model_path: None, folded_constant_names: vec![], }; meta.save(&slices_dir.join("metadata.msgpack")).unwrap(); ensure_test_artifacts(&slices_dir); let out_none = tmp.path().join("out_none"); let out_bn = tmp.path().join("out_bn"); let out_gl = tmp.path().join("out_gl"); let config_none = PackageConfig { output_dir: out_none.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let config_bn = PackageConfig { output_dir: out_bn.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: Some("bn254".to_string()), }; let config_gl = PackageConfig { output_dir: out_gl.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: Some("goldilocks".to_string()), }; package_content_addressed(&slices_dir, &config_none).unwrap(); package_content_addressed(&slices_dir, &config_bn).unwrap(); package_content_addressed(&slices_dir, &config_gl).unwrap(); let m_none: serde_json::Value = rmp_serde::from_slice(&fs::read(out_none.join("manifest.msgpack")).unwrap()).unwrap(); let m_bn: serde_json::Value = 
rmp_serde::from_slice(&fs::read(out_bn.join("manifest.msgpack")).unwrap()).unwrap(); let m_gl: serde_json::Value = rmp_serde::from_slice(&fs::read(out_gl.join("manifest.msgpack")).unwrap()).unwrap(); let h_none = m_none["components"][0]["sha256"].as_str().unwrap(); let h_bn = m_bn["components"][0]["sha256"].as_str().unwrap(); let h_gl = m_gl["components"][0]["sha256"].as_str().unwrap(); assert_ne!(h_none, h_bn, "onnx: curve=None vs bn254 should differ"); assert_ne!(h_none, h_gl, "onnx: curve=None vs goldilocks should differ"); assert_ne!(h_bn, h_gl, "onnx: bn254 vs goldilocks should differ"); } #[test] fn test_invalid_curve_rejected() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); create_test_model_metadata(&slices_dir, 1); let config_typo = PackageConfig { output_dir: tmp.path().join("output"), author: None, model_version: None, model_name: None, timeout: None, curve: Some("bm254".to_string()), }; let result = package_content_addressed(&slices_dir, &config_typo); assert!(result.is_err()); let config_empty = PackageConfig { output_dir: tmp.path().join("output2"), author: None, model_version: None, model_name: None, timeout: None, curve: Some("".to_string()), }; let result = package_content_addressed(&slices_dir, &config_empty); assert!(result.is_err()); } #[test] fn test_curve_normalization() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); create_test_model_metadata(&slices_dir, 1); let out1 = tmp.path().join("out1"); let out2 = tmp.path().join("out2"); let out3 = tmp.path().join("out3"); let config1 = PackageConfig { output_dir: out1.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: Some("bn254".to_string()), }; let config2 = PackageConfig { output_dir: out2.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: Some(" bn254 
".to_string()), }; let config3 = PackageConfig { output_dir: out3.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: Some("BN254".to_string()), }; package_content_addressed(&slices_dir, &config1).unwrap(); package_content_addressed(&slices_dir, &config2).unwrap(); package_content_addressed(&slices_dir, &config3).unwrap(); let m1: serde_json::Value = rmp_serde::from_slice(&fs::read(out1.join("manifest.msgpack")).unwrap()).unwrap(); let m2: serde_json::Value = rmp_serde::from_slice(&fs::read(out2.join("manifest.msgpack")).unwrap()).unwrap(); let m3: serde_json::Value = rmp_serde::from_slice(&fs::read(out3.join("manifest.msgpack")).unwrap()).unwrap(); assert_eq!(m1["components"][0]["sha256"], m2["components"][0]["sha256"]); assert_eq!(m1["components"][0]["sha256"], m3["components"][0]["sha256"]); } #[test] fn test_deduplication_shared_circuits() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); let mut slices = Vec::new(); for i in 0..3 { let slice_dir = slices_dir.join(format!("slice_{}", i)); let payload_dir = slice_dir.join("payload"); fs::create_dir_all(&payload_dir).unwrap(); write_minimal_onnx(&payload_dir.join(format!("slice_{}.onnx", i)), 64); let circuit_dir = slice_dir.join("jstprove").join("circuit.bundle"); fs::create_dir_all(&circuit_dir).unwrap(); fs::write(circuit_dir.join("circuit.bin"), "shared_circuit_data").unwrap(); slices.push(SliceMetadata { index: i, filename: format!("slice_{}.onnx", i), path: slice_dir.to_string_lossy().to_string(), relative_path: format!("slice_{}/payload/slice_{}.onnx", i, i), shape: SliceShapeWrapper { tensor_shape: TensorShape { input: vec![vec![1, 64]], output: vec![vec![1, 64]], }, }, dependencies: Dependencies { input: vec![format!("t_{}", i)], output: vec![format!("t_{}", i + 1)], filtered_inputs: vec![], }, tiling: None, channel_split: None, dim_split: None, compilation: Compilation::default(), 
slice_metadata: None, slice_metadata_relative_path: None, }); } let meta = ModelMetadata { original_model: "shared_test".to_string(), model_type: "onnx".to_string(), input_shape: vec![vec![1, 64]], output_shapes: vec![vec![1, 64]], output_names: vec!["out".to_string()], slice_points: vec![0, 1, 2], slices, dsperse_version: None, dsperse_rev: None, jstprove_version: None, jstprove_rev: None, traced_shapes: None, traced_types: None, original_model_path: None, folded_constant_names: vec![], }; meta.save(&slices_dir.join("metadata.msgpack")).unwrap(); ensure_test_artifacts(&slices_dir); let output_dir = tmp.path().join("output"); let config = PackageConfig { output_dir: output_dir.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let result = package_content_addressed(&slices_dir, &config).unwrap(); assert_eq!(result.component_count, 1); assert_eq!(result.wb_count, 3); let manifest: serde_json::Value = rmp_serde::from_slice(&fs::read(output_dir.join("manifest.msgpack")).unwrap()).unwrap(); let components = manifest["components"].as_array().unwrap(); let hash0 = components[0]["sha256"].as_str().unwrap(); let hash1 = components[1]["sha256"].as_str().unwrap(); let hash2 = components[2]["sha256"].as_str().unwrap(); assert_eq!(hash0, hash1); assert_eq!(hash1, hash2); } #[test] fn test_uncompiled_onnx_only_slice() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); let slice_dir = slices_dir.join("slice_0"); let payload_dir = slice_dir.join("payload"); fs::create_dir_all(&payload_dir).unwrap(); fs::write(payload_dir.join("slice_0.onnx"), "onnx_payload_data").unwrap(); let meta = ModelMetadata { original_model: "test".to_string(), model_type: "onnx".to_string(), input_shape: vec![vec![1, 3, 224, 224]], output_shapes: vec![vec![1, 1000]], output_names: vec!["output".to_string()], slice_points: vec![0], slices: vec![SliceMetadata { index: 0, filename: 
"slice_0.onnx".to_string(), path: slice_dir.to_string_lossy().to_string(), relative_path: "slice_0/payload/slice_0.onnx".to_string(), shape: SliceShapeWrapper { tensor_shape: TensorShape { input: vec![vec![1, 3, 224, 224]], output: vec![vec![1, 1000]], }, }, dependencies: Dependencies { input: vec!["input".to_string()], output: vec!["output".to_string()], filtered_inputs: vec![], }, tiling: None, channel_split: None, dim_split: None, compilation: Compilation::default(), slice_metadata: None, slice_metadata_relative_path: None, }], dsperse_version: None, dsperse_rev: None, jstprove_version: None, jstprove_rev: None, traced_shapes: None, traced_types: None, original_model_path: None, folded_constant_names: vec![], }; meta.save(&slices_dir.join("metadata.msgpack")).unwrap(); ensure_test_artifacts(&slices_dir); let output_dir = tmp.path().join("output"); let config = PackageConfig { output_dir: output_dir.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let result = package_content_addressed(&slices_dir, &config).unwrap(); assert_eq!(result.component_count, 1); let manifest: serde_json::Value = rmp_serde::from_slice(&fs::read(output_dir.join("manifest.msgpack")).unwrap()).unwrap(); let comp = &manifest["components"][0]; assert!(comp["proof_system"].is_null()); let sha = comp["sha256"].as_str().unwrap(); let files = comp["files"].as_array().unwrap(); assert_eq!(files.len(), 1); assert_eq!(files[0], "slice_0.onnx"); let comp_dir = output_dir.join("components").join(sha); assert!(comp_dir.join("slice_0.onnx").is_file()); } #[test] fn test_missing_artifact_errors() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); let slice_dir = slices_dir.join("slice_0"); fs::create_dir_all(&slice_dir).unwrap(); let meta = ModelMetadata { original_model: "test".to_string(), model_type: "onnx".to_string(), input_shape: vec![vec![1]], output_shapes: 
vec![vec![1]], output_names: vec!["out".to_string()], slice_points: vec![0], slices: vec![SliceMetadata { index: 0, filename: "slice_0.onnx".to_string(), path: slice_dir.to_string_lossy().to_string(), relative_path: "slice_0/payload/slice_0.onnx".to_string(), shape: SliceShapeWrapper { tensor_shape: TensorShape { input: vec![vec![1]], output: vec![vec![1]], }, }, dependencies: Dependencies { input: vec!["in".to_string()], output: vec!["out".to_string()], filtered_inputs: vec![], }, tiling: None, channel_split: None, dim_split: None, compilation: Compilation::default(), slice_metadata: None, slice_metadata_relative_path: None, }], dsperse_version: None, dsperse_rev: None, jstprove_version: None, jstprove_rev: None, traced_shapes: None, traced_types: None, original_model_path: None, folded_constant_names: vec![], }; meta.save(&slices_dir.join("metadata.msgpack")).unwrap(); ensure_test_artifacts(&slices_dir); let config = PackageConfig { output_dir: tmp.path().join("output"), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let result = package_content_addressed(&slices_dir, &config); assert!(result.is_err()); let err = result.unwrap_err().to_string(); assert!( err.contains("no circuit directory or ONNX artifact"), "unexpected error: {err}" ); } #[test] fn test_path_traversal_rejected() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); let slice_dir = slices_dir.join("slice_0"); let payload_dir = slice_dir.join("payload"); fs::create_dir_all(&payload_dir).unwrap(); fs::write(payload_dir.join("slice_0.onnx"), "data").unwrap(); let meta = ModelMetadata { original_model: "test".to_string(), model_type: "onnx".to_string(), input_shape: vec![vec![1]], output_shapes: vec![vec![1]], output_names: vec!["out".to_string()], slice_points: vec![0], slices: vec![SliceMetadata { index: 0, filename: "slice_0.onnx".to_string(), path: 
slice_dir.to_string_lossy().to_string(), relative_path: "../../etc/passwd".to_string(), shape: SliceShapeWrapper { tensor_shape: TensorShape { input: vec![vec![1]], output: vec![vec![1]], }, }, dependencies: Dependencies { input: vec!["in".to_string()], output: vec!["out".to_string()], filtered_inputs: vec![], }, tiling: None, channel_split: None, dim_split: None, compilation: Compilation::default(), slice_metadata: None, slice_metadata_relative_path: None, }], dsperse_version: None, dsperse_rev: None, jstprove_version: None, jstprove_rev: None, traced_shapes: None, traced_types: None, original_model_path: None, folded_constant_names: vec![], }; meta.save(&slices_dir.join("metadata.msgpack")).unwrap(); ensure_test_artifacts(&slices_dir); let config = PackageConfig { output_dir: tmp.path().join("output"), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let result = package_content_addressed(&slices_dir, &config); assert!(result.is_err()); let err = result.unwrap_err().to_string(); assert!( err.contains("path traversal"), "expected path traversal error, got: {err}" ); } #[test] fn test_nonexistent_dir() { let config = PackageConfig { output_dir: PathBuf::from("/tmp/nonexistent_output"), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let result = package_content_addressed(Path::new("/nonexistent/path"), &config); assert!(result.is_err()); } #[test] fn test_identical_bytes_different_filenames_distinct_hashes() { let tmp = TempDir::new().unwrap(); let slices_dir = tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); let identical_data = "identical_onnx_content"; let mut slices = Vec::new(); for i in 0..2 { let slice_dir = slices_dir.join(format!("slice_{}", i)); let payload_dir = slice_dir.join("payload"); fs::create_dir_all(&payload_dir).unwrap(); fs::write( payload_dir.join(format!("slice_{}.onnx", i)), identical_data, ) .unwrap(); slices.push(SliceMetadata { index: 
i, filename: format!("slice_{}.onnx", i), path: slice_dir.to_string_lossy().to_string(), relative_path: format!("slice_{}/payload/slice_{}.onnx", i, i), shape: SliceShapeWrapper { tensor_shape: TensorShape { input: vec![vec![1]], output: vec![vec![1]], }, }, dependencies: Dependencies { input: vec![format!("t_{}", i)], output: vec![format!("t_{}", i + 1)], filtered_inputs: vec![], }, tiling: None, channel_split: None, dim_split: None, compilation: Compilation::default(), slice_metadata: None, slice_metadata_relative_path: None, }); } let meta = ModelMetadata { original_model: "test".to_string(), model_type: "onnx".to_string(), input_shape: vec![vec![1]], output_shapes: vec![vec![1]], output_names: vec!["out".to_string()], slice_points: vec![0, 1], slices, dsperse_version: None, dsperse_rev: None, jstprove_version: None, jstprove_rev: None, traced_shapes: None, traced_types: None, original_model_path: None, folded_constant_names: vec![], }; meta.save(&slices_dir.join("metadata.msgpack")).unwrap(); ensure_test_artifacts(&slices_dir); let output_dir = tmp.path().join("output"); let config = PackageConfig { output_dir: output_dir.clone(), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let result = package_content_addressed(&slices_dir, &config).unwrap(); assert_eq!(result.component_count, 2); let manifest: serde_json::Value = rmp_serde::from_slice(&fs::read(output_dir.join("manifest.msgpack")).unwrap()).unwrap(); let c0 = &manifest["components"][0]; let c1 = &manifest["components"][1]; assert_ne!(c0["sha256"], c1["sha256"]); let dir0 = output_dir .join("components") .join(c0["sha256"].as_str().unwrap()); let dir1 = output_dir .join("components") .join(c1["sha256"].as_str().unwrap()); assert!(dir0.join("slice_0.onnx").is_file()); assert!(dir1.join("slice_1.onnx").is_file()); } #[test] #[cfg(unix)] fn test_symlink_payload_rejected() { use std::os::unix::fs::symlink; let tmp = TempDir::new().unwrap(); let slices_dir = 
tmp.path().join("model").join("slices"); fs::create_dir_all(&slices_dir).unwrap(); let external = tmp.path().join("external_secret.bin"); fs::write(&external, "sensitive data").unwrap(); let slice_dir = slices_dir.join("slice_0"); let payload_dir = slice_dir.join("payload"); fs::create_dir_all(&payload_dir).unwrap(); symlink(&external, payload_dir.join("slice_0.onnx")).unwrap(); let meta = ModelMetadata { original_model: "test".to_string(), model_type: "onnx".to_string(), input_shape: vec![vec![1]], output_shapes: vec![vec![1]], output_names: vec!["out".to_string()], slice_points: vec![0], slices: vec![SliceMetadata { index: 0, filename: "slice_0.onnx".to_string(), path: slice_dir.to_string_lossy().to_string(), relative_path: "slice_0/payload/slice_0.onnx".to_string(), shape: SliceShapeWrapper { tensor_shape: TensorShape { input: vec![vec![1]], output: vec![vec![1]], }, }, dependencies: Dependencies { input: vec!["in".to_string()], output: vec!["out".to_string()], filtered_inputs: vec![], }, tiling: None, channel_split: None, dim_split: None, compilation: Compilation::default(), slice_metadata: None, slice_metadata_relative_path: None, }], dsperse_version: None, dsperse_rev: None, jstprove_version: None, jstprove_rev: None, traced_shapes: None, traced_types: None, original_model_path: None, folded_constant_names: vec![], }; meta.save(&slices_dir.join("metadata.msgpack")).unwrap(); ensure_test_artifacts(&slices_dir); let config = PackageConfig { output_dir: tmp.path().join("output"), author: None, model_version: None, model_name: None, timeout: None, curve: None, }; let result = package_content_addressed(&slices_dir, &config); assert!(result.is_err()); let err = result.unwrap_err().to_string(); assert!( err.contains("symlink"), "expected symlink error, got: {err}" ); } } ================================================ FILE: crates/dsperse/src/pipeline/prover.rs ================================================ use std::path::Path; use crate::backend::ProofBackend; 
use crate::error::Result; use crate::schema::execution::RunMetadata; use super::stage::{PipelineStage, run_pipeline_stage};
/// Run the proving stage for an existing run directory.
///
/// Thin wrapper: forwards `run_dir`, `slices_dir`, `backend` and `parallel`
/// unchanged to `run_pipeline_stage` with `PipelineStage::Prove` and returns
/// that call's result as-is.
/// NOTE(review): `parallel` is passed straight through; its exact meaning
/// (presumably a worker/concurrency count) is defined by `run_pipeline_stage`.
pub fn prove_run( run_dir: &Path, slices_dir: &Path, backend: &dyn ProofBackend, parallel: usize, ) -> Result { run_pipeline_stage(PipelineStage::Prove, run_dir, slices_dir, backend, parallel) }
================================================ FILE: crates/dsperse/src/pipeline/publisher.rs ================================================
use std::fs; use std::path::Path; use std::time::Duration; use sha2::{Digest, Sha256}; use crate::error::{DsperseError, Result};
/// Default timeout installed on the registry HTTP client for every request.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Longer per-request timeout that overrides the default for each
/// component-file upload (PUT to a pre-signed upload URL).
const UPLOAD_TIMEOUT: Duration = Duration::from_secs(300);
/// Settings consumed by `publish`.
pub struct PublishConfig {
    /// Registry base URL; a trailing `/` is trimmed before endpoint paths
    /// such as `/components/...` and `/admin/components` are appended.
    pub api_url: String,
    /// Token sent as `Authorization: Bearer <token>` on registry admin
    /// requests (e.g. component registration).
    pub auth_token: String,
    // NOTE(review): `name`, `description`, `author`, `version`, `timeout`
    // and `activate` are not read in the visible part of `publish_async`;
    // presumably they populate the model registration step later -- confirm.
    pub name: String,
    pub description: String,
    pub author: String,
    pub version: String,
    /// Fallback proof system for components whose manifest entry has no
    /// `proof_system`; the chosen value is upper-cased before registration.
    pub proof_system: String,
    pub timeout: u64,
    pub activate: bool,
}
/// Summary returned by `publish`.
pub struct PublishResult {
    // NOTE(review): how `model_id` is obtained is not visible in this chunk.
    pub model_id: String,
    /// Components registered and whose files were all uploaded by this run.
    pub components_uploaded: usize,
    /// Components not uploaded because every file was already present in
    /// the registry, or because registration reported them as already
    /// registered (HTTP 409 or an "already exists" error body).
    pub components_skipped: usize,
    /// Presumably the number of weight blobs uploaded by this run -- confirm.
    pub weights_uploaded: usize,
    /// Weight blobs skipped because a ranged GET on `/models/wb/{sha}`
    /// reported them as already present.
    pub weights_skipped: usize,
}
/// Format `token` as an HTTP Bearer credential (`"Bearer <token>"`) for the
/// `Authorization` header.
fn auth_header(token: &str) -> String { format!("Bearer {token}") }
/// Synchronous entry point for publishing the package in `dir`.
///
/// Builds a single-threaded Tokio runtime with all drivers enabled and blocks
/// on `publish_async(dir, config)`.
///
/// Errors: a runtime-construction failure is returned as
/// `DsperseError::Other("tokio runtime: ...")`. Otherwise the result of
/// `publish_async` is returned unchanged.
/// NOTE(review): Tokio's `Runtime::block_on` panics when called from within
/// an async runtime context; confirm that all callers are synchronous.
pub fn publish(dir: &Path, config: &PublishConfig) -> Result { let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .map_err(|e| DsperseError::Other(format!("tokio runtime: {e}")))?; rt.block_on(publish_async(dir, config)) }
async fn publish_async(dir: &Path, config: &PublishConfig) -> Result { let manifest_path = dir.join("manifest.msgpack"); if !manifest_path.is_file() { return Err(DsperseError::Other(format!( "manifest.msgpack not found in {}", dir.display() ))); } let manifest_bytes = fs::read(&manifest_path).map_err(|e| DsperseError::io(e, &manifest_path))?; let manifest: serde_json::Value = rmp_serde::from_slice(&manifest_bytes) .map_err(|e| DsperseError::Other(format!("failed to parse manifest: {e}")))?; let components = manifest["components"] .as_array() .ok_or_else(|| 
DsperseError::Other("manifest missing components array".into()))?; let dag = manifest["dag"] .as_array() .ok_or_else(|| DsperseError::Other("manifest missing dag array".into()))?; let client = reqwest::Client::builder() .timeout(REQUEST_TIMEOUT) .build() .map_err(|e| DsperseError::Other(format!("http client: {e}")))?; let api = config.api_url.trim_end_matches('/'); let auth = auth_header(&config.auth_token); let mut components_uploaded = 0usize; let mut components_skipped = 0usize; let mut weights_uploaded = 0usize; let mut weights_skipped = 0usize; for comp in components { let sha = comp["sha256"] .as_str() .ok_or_else(|| DsperseError::Other("component missing sha256".into()))?; let files: Vec = comp["files"] .as_array() .ok_or_else(|| DsperseError::Other("component missing files".into()))? .iter() .filter_map(|v| v.as_str().map(String::from)) .collect(); // Verify the component by probing each file the manifest // expects to live in blob storage, not by asking for a // metadata row. The registry registers the component row // as soon as POST /admin/components returns, but the // per-file PUTs against the pre-signed upload URLs happen // afterwards -- any failure there (timeout, network blip, // interrupted publish process) leaves the row present with // no backing files. A plain GET /components/{sha} sees the // row and reports "exists, skipping", so every subsequent // publish run re-skips the broken component and the model // stays permanently half-uploaded from downstream // consumers' perspective. // // Mirror the byte-level presence check weight-blob uploads // below already use: HEAD each expected file by issuing a // single-byte ranged GET to the blob path. If every file // is present, the component is genuinely done and we skip. // If every file is missing, proceed to the normal // register + upload path. 
If the set is partially present // (registered but mid-upload), surface an actionable error // instead of silently continuing, because re-registering // via POST /admin/components will 409 and the current flow // has no way to request fresh upload URLs for that sha. // A manifest entry with zero files is malformed -- the // empty-list case would otherwise make both // `missing.is_empty()` and `present == files.len()` true // below, silently classifying the component as present // without any actual bytes verified. Fail loud. if files.is_empty() { return Err(DsperseError::Other(format!( "component {sha} has no files listed in the manifest; refusing to treat as present" ))); } let mut present = 0usize; let mut missing: Vec = Vec::new(); for filename in &files { let file_url = format!("{api}/components/{sha}/files/{filename}"); let probe = client .get(&file_url) .header("Range", "bytes=0-0") .send() .await .map_err(|e| DsperseError::Other(format!("probe {sha}/{filename}: {e}")))?; // A Range: bytes=0-0 GET against a blob path has two // legitimate success replies: 206 (partial content, // what the blob store returns when it honours the // range) and 200 (full content, what it returns when // it ignores the range for an empty body or tiny file). // Any other 2xx (201 Created, 202 Accepted, 204 No // Content) is ambiguous for a GET on a CAS path and // should not be interpreted as "file present". 
let status = probe.status(); match status.as_u16() { 200 | 206 => present += 1, 404 => missing.push(filename.clone()), _ => { let text = probe.text().await.unwrap_or_default(); return Err(DsperseError::Other(format!( "probe component {sha}/{filename} returned unexpected status ({status}): {text}" ))); } } } if missing.is_empty() && present == files.len() { tracing::info!(sha = %sha, "component files present, skipping"); components_skipped += 1; continue; } if present > 0 { return Err(DsperseError::Other(format!( "component {sha} is partially uploaded: {present}/{} files present, \ missing: {:?}. A previous publish registered the component row but \ some PUTs did not complete. Run \ `curl -X DELETE -H 'Authorization: Bearer $REGISTRY_AUTH_TOKEN' \ {api}/admin/components/{sha}` to drop the stale row, then re-run \ publish so the full register + upload flow can replay for this sha.", files.len(), missing ))); } let proof_system = comp["proof_system"] .as_str() .unwrap_or(&config.proof_system) .to_uppercase(); let comp_name = comp["name"].as_str().unwrap_or(sha); tracing::info!(sha = %sha, files = files.len(), "registering component"); let register_resp = client .post(format!("{api}/admin/components")) .header("Authorization", &auth) .json(&serde_json::json!({ "sha256": sha, "name": comp_name, "description": "", "proof_system": proof_system, "files": files, })) .send() .await .map_err(|e| DsperseError::Other(format!("register component {sha}: {e}")))?; let reg_status = register_resp.status(); if reg_status.as_u16() == 409 { tracing::info!(sha = %sha, "component already registered (conflict)"); components_skipped += 1; continue; } if !reg_status.is_success() { let text = register_resp.text().await.unwrap_or_default(); if text.contains("already exists") { tracing::info!(sha = %sha, "component already registered"); components_skipped += 1; continue; } return Err(DsperseError::Other(format!( "register component {sha} failed ({reg_status}): {text}" ))); } let resp_body: 
serde_json::Value = register_resp .json() .await .map_err(|e| DsperseError::Other(format!("parse component response: {e}")))?; let upload_urls = resp_body["upload_urls"] .as_object() .ok_or_else(|| DsperseError::Other("missing upload_urls for component".into()))?; let comp_dir = dir.join("components").join(sha); for (filename, url_val) in upload_urls { let url = url_val .as_str() .ok_or_else(|| DsperseError::Other(format!("non-string URL for {filename}")))?; let file_path = comp_dir.join(filename); let data = fs::read(&file_path).map_err(|e| DsperseError::io(e, &file_path))?; tracing::info!(file = %filename, size = data.len(), "uploading component file"); let put = client .put(url) .timeout(UPLOAD_TIMEOUT) .header("Content-Type", "application/octet-stream") .body(data) .send() .await .map_err(|e| DsperseError::Other(format!("upload {filename}: {e}")))?; if !put.status().is_success() { return Err(DsperseError::Other(format!( "upload component file {filename} failed ({})", put.status() ))); } } components_uploaded += 1; } let mut all_weight_refs: Vec<&serde_json::Value> = Vec::new(); if let Some(artifacts) = manifest["artifacts"].as_array() { all_weight_refs.extend(artifacts); } for comp in components { if let Some(weights) = comp["weights"].as_array() { all_weight_refs.extend(weights); } } let mut uploaded_wbs: std::collections::HashSet = std::collections::HashSet::new(); for wref in &all_weight_refs { let sha = wref["sha256"] .as_str() .ok_or_else(|| DsperseError::Other("weight ref missing sha256".into()))?; if uploaded_wbs.contains(sha) { continue; } let size = wref["size_bytes"].as_u64().unwrap_or(0); let check = client .get(format!("{api}/models/wb/{sha}")) .header("Range", "bytes=0-0") .send() .await .map_err(|e| DsperseError::Other(format!("check wb {sha}: {e}")))?; if check.status().is_success() || check.status().as_u16() == 206 { tracing::info!(sha = %sha, "weight blob exists, skipping"); weights_skipped += 1; uploaded_wbs.insert(sha.to_string()); continue; 
} if check.status().as_u16() != 404 { let status = check.status(); let text = check.text().await.unwrap_or_default(); return Err(DsperseError::Other(format!( "probe wb {sha} returned unexpected status ({status}): {text}" ))); } let name = wref["role"].as_str().unwrap_or(""); tracing::info!(sha = %sha, size, "registering weight blob"); let wb_resp = client .post(format!("{api}/admin/models/wb")) .header("Authorization", &auth) .json(&serde_json::json!({ "sha256": sha, "name": name, "size_bytes": size, })) .send() .await .map_err(|e| DsperseError::Other(format!("register wb {sha}: {e}")))?; let wb_status = wb_resp.status(); if wb_status.as_u16() == 409 { tracing::info!(sha = %sha, "weight blob already registered (conflict)"); weights_skipped += 1; uploaded_wbs.insert(sha.to_string()); continue; } if !wb_status.is_success() { let text = wb_resp.text().await.unwrap_or_default(); if text.contains("already exists") { tracing::info!(sha = %sha, "weight blob already registered"); weights_skipped += 1; uploaded_wbs.insert(sha.to_string()); continue; } return Err(DsperseError::Other(format!( "register wb {sha} failed ({wb_status}): {text}" ))); } let wb_body: serde_json::Value = wb_resp .json() .await .map_err(|e| DsperseError::Other(format!("parse wb response: {e}")))?; match wb_body["upload_url"].as_str() { Some(upload_url) => { let wb_path = dir.join("wb").join(sha); let data = fs::read(&wb_path).map_err(|e| DsperseError::io(e, &wb_path))?; tracing::info!(sha = %sha, size = data.len(), "uploading weight blob"); let put = client .put(upload_url) .timeout(UPLOAD_TIMEOUT) .header("Content-Type", "application/octet-stream") .body(data) .send() .await .map_err(|e| DsperseError::Other(format!("upload wb {sha}: {e}")))?; if !put.status().is_success() { return Err(DsperseError::Other(format!( "upload wb {sha} failed ({})", put.status() ))); } weights_uploaded += 1; } None => { return Err(DsperseError::Other(format!( "registry returned no upload URL for weight blob {sha}" ))); } } 
uploaded_wbs.insert(sha.to_string()); } let model_info = &manifest["model"]; let model_name = model_info["name"].as_str().unwrap_or(config.name.as_str()); let model_author = model_info["author"] .as_str() .unwrap_or(config.author.as_str()); let model_version = model_info["version"] .as_str() .unwrap_or(config.version.as_str()); let model_timeout = model_info["timeout"].as_u64().unwrap_or(config.timeout); let input_schema = &model_info["input_schema"]; let dsperse_version = model_info["dsperse_version"].as_str(); let jstprove_version = model_info["jstprove_version"].as_str(); let artifacts = manifest["artifacts"] .as_array() .cloned() .unwrap_or_default(); let composition = serde_json::json!({ "version": 1, "artifacts": artifacts, "components": components, "dag": dag, }); let mut model_hasher = Sha256::new(); model_hasher.update(model_name.as_bytes()); model_hasher.update(b"\x00"); model_hasher.update(model_author.as_bytes()); model_hasher.update(b"\x00"); model_hasher.update(model_version.as_bytes()); model_hasher.update(b"\x00"); model_hasher.update(model_timeout.to_le_bytes()); model_hasher.update(b"\x00"); let comp_json = serde_json::to_string(&composition) .map_err(|e| DsperseError::Other(format!("serialize composition: {e}")))?; model_hasher.update(comp_json.as_bytes()); let model_id = format!("{:x}", model_hasher.finalize()); tracing::info!(id = %model_id, "creating model"); let model_resp = client .post(format!("{api}/admin/models")) .header("Authorization", &auth) .json(&serde_json::json!({ "id": model_id, "metadata": { "name": model_name, "description": config.description, "author": model_author, "version": model_version, "netuid": null, "weights_version": null, "timeout": model_timeout, "input_schema": input_schema, "dsperse_version": dsperse_version, "jstprove_version": jstprove_version, }, "composition": composition, })) .send() .await .map_err(|e| DsperseError::Other(format!("create model: {e}")))?; if !model_resp.status().is_success() { let status = 
model_resp.status(); let text = model_resp.text().await.unwrap_or_default(); if !text.contains("already exists") { return Err(DsperseError::Other(format!( "create model failed ({status}): {text}" ))); } tracing::info!(id = %model_id, "model already exists"); } if config.activate { tracing::info!(id = %model_id, "activating model"); let activate_resp = client .patch(format!("{api}/admin/models/{model_id}")) .header("Authorization", &auth) .json(&serde_json::json!({ "is_active": true })) .send() .await .map_err(|e| DsperseError::Other(format!("activate: {e}")))?; if !activate_resp.status().is_success() { let status = activate_resp.status(); let text = activate_resp.text().await.unwrap_or_default(); return Err(DsperseError::Other(format!( "activate failed ({status}): {text}" ))); } } Ok(PublishResult { model_id, components_uploaded, components_skipped, weights_uploaded, weights_skipped, }) } ================================================ FILE: crates/dsperse/src/pipeline/runner.rs ================================================ use std::collections::HashMap; use std::path::{Path, PathBuf}; use ndarray::{ArrayD, IxDyn}; use jstprove_circuits::api::CircuitParamsType as CircuitParams; use super::strategy::ExecutionStrategy; use super::tensor_store::TensorStore; use crate::backend::jstprove::JstproveBackend; use crate::backend::onnx::NamedOutputs; use crate::error::{DsperseError, Result}; use crate::schema::execution::{ ExecutionChain, ExecutionInfo, ExecutionMethod, ExecutionNode, ExecutionResultEntry, RunMetadata, }; use crate::schema::metadata::{BackendKind, ModelMetadata, RunSliceMetadata}; use crate::slicer::onnx_proto::TensorProto; use crate::utils::io::{ arrayd_to_value, build_msgpack_map, extract_input_data, map_get_ref, read_msgpack, value_to_arrayd, write_msgpack, }; use crate::utils::paths::{find_metadata_path, resolve_relative_path, slice_dir_path}; use rmpv::Value; pub struct RunConfig { pub parallel: usize, pub batch: bool, pub weights_onnx: Option, pub 
combined: bool, } impl Default for RunConfig { fn default() -> Self { Self { parallel: 1, batch: false, weights_onnx: None, combined: true, } } } fn resolve_circuit_path_required( slices_dir: &Path, circuit_path: Option<&str>, label: &str, ) -> Result { circuit_path .map(|p| resolve_relative_path(slices_dir, p)) .transpose()? .ok_or_else(|| DsperseError::Pipeline(format!("no circuit path for {label}"))) } pub(crate) fn resolve_circuit_path_optional( slices_dir: &Path, circuit_path: Option<&str>, ) -> Result> { circuit_path .map(|p| resolve_relative_path(slices_dir, p)) .transpose() } pub fn load_model_metadata(slices_dir: &Path) -> Result { let meta_path = find_metadata_path(slices_dir).ok_or_else(|| { DsperseError::Metadata(format!( "no {} in slices", crate::utils::paths::METADATA_FILE )) })?; let mut model_meta = ModelMetadata::load(&meta_path)?; if model_meta.slices.is_empty() { return Err(DsperseError::Metadata(format!( "{} has no slices in {}", crate::utils::paths::METADATA_FILE, slices_dir.display() ))); } model_meta.slices.sort_by_key(|s| s.index); Ok(model_meta) } fn validate_weights_onnx( donor_init_map: &HashMap, model_meta: &ModelMetadata, slices_dir: &Path, ) -> Result<()> { for slice in &model_meta.slices { let onnx_path = slice.resolve_onnx(slices_dir)?; if !onnx_path.exists() { return Err(DsperseError::Pipeline(format!( "slice_{} ONNX not found at {}", slice.index, onnx_path.display() ))); } let slice_model = crate::slicer::onnx_proto::load_model(&onnx_path)?; let slice_graph = slice_model.graph.as_ref().ok_or_else(|| { DsperseError::Pipeline(format!( "slice_{} ONNX at {} has no graph", slice.index, onnx_path.display() )) })?; let context = format!("slice_{}", slice.index); crate::slicer::onnx_proto::validate_initializer_compatibility( &slice_graph.initializer, donor_init_map, &context, )?; } Ok(()) } fn load_donor_model( weights_onnx: Option<&PathBuf>, ) -> Result> { let weights_path = match weights_onnx { Some(p) => p, None => return Ok(None), }; 
if !weights_path.is_file() { return Err(DsperseError::Other(format!( "consumer weights ONNX not found: {}", weights_path.display() ))); } Ok(Some(crate::slicer::onnx_proto::load_model(weights_path)?)) } fn donor_init_map( model: Option<&crate::slicer::onnx_proto::ModelProto>, ) -> Result>> { match model { Some(m) => { let graph = m.graph.as_ref().ok_or_else(|| { DsperseError::Pipeline("consumer weights ONNX missing graph".into()) })?; Ok(Some(crate::slicer::onnx_proto::build_initializer_map( graph, ))) } None => Ok(None), } } pub fn run_inference( slices_dir: &Path, input_path: &Path, run_dir: &Path, backend: &JstproveBackend, config: &RunConfig, ) -> Result { let model_meta = load_model_metadata(slices_dir)?; if config.combined && model_meta.original_model_path.is_some() && model_meta.traced_shapes.is_some() { return run_combined_inference( slices_dir, input_path, run_dir, backend, config, &model_meta, ); } else if config.combined { tracing::warn!( "combined mode requested but metadata missing original_model_path or traced_shapes, using per-slice execution" ); } if model_meta.original_model_path.is_some() { crate::slicer::materializer::ensure_all_slices_materialized(slices_dir, &model_meta)?; } let donor_model = load_donor_model(config.weights_onnx.as_ref())?; let donor_map = donor_init_map(donor_model.as_ref())?; if let Some(ref map) = donor_map { validate_weights_onnx(map, &model_meta, slices_dir)?; tracing::info!( weights = %config.weights_onnx.as_ref().unwrap().display(), "validated consumer weights ONNX" ); } std::fs::create_dir_all(run_dir).map_err(|e| DsperseError::io(e, run_dir))?; let input_data = read_msgpack(input_path)?; let chain = build_execution_chain(&model_meta, slices_dir)?; let run_meta = build_run_metadata(&model_meta, slices_dir, &chain)?; let mut tensor_cache = TensorStore::new(); let input_val = extract_input_data(&input_data).ok_or_else(|| { DsperseError::Pipeline( "input has no recognized input key (input_data, input, data, 
inputs)".into(), ) })?; let first_slice = model_meta .slices .first() .ok_or_else(|| DsperseError::Pipeline("model has no slices".into()))?; let declared_inputs = &first_slice.dependencies.filtered_inputs; if declared_inputs.is_empty() { return Err(DsperseError::Pipeline( "first slice has no input dependency".into(), )); } if input_val.is_map() { for name in declared_inputs { let v = map_get_ref(input_val, name) .ok_or_else(|| DsperseError::Pipeline(format!("input map missing key {name:?}")))?; tensor_cache.put(name.clone(), value_to_arrayd(v)?); } } else if declared_inputs.len() == 1 { tensor_cache.put(declared_inputs[0].clone(), value_to_arrayd(input_val)?); } else { return Err(DsperseError::Pipeline(format!( "model declares {} inputs but input is not a map", declared_inputs.len() ))); } let input_copy = run_dir.join(crate::utils::paths::INPUT_FILE); write_msgpack(&input_copy, &input_data)?; let mut results: Vec = Vec::new(); let mut current = chain.head.clone(); while let Some(slice_id) = current.take() { let node = chain .nodes .get(&slice_id) .ok_or_else(|| DsperseError::Pipeline(format!("missing node {slice_id}")))?; let slice_meta = run_meta.slices.get(&slice_id).ok_or_else(|| { DsperseError::Pipeline(format!("missing run slice metadata {slice_id}")) })?; let slice_run_dir = run_dir.join(&slice_id); std::fs::create_dir_all(&slice_run_dir).map_err(|e| DsperseError::io(e, &slice_run_dir))?; tracing::info!(slice = %slice_id, circuit = node.use_circuit, "executing"); let exec_result = execute_slice( slices_dir, &slice_run_dir, &slice_id, node, slice_meta, &mut tensor_cache, backend, config, donor_map.as_ref(), ); let exec_info = match exec_result { Ok(info) => info, Err(e) => { tracing::error!(slice = %slice_id, error = %e, "execution failed"); let method = ExecutionStrategy::from_metadata(slice_meta, node.use_circuit) .map(|s| s.execution_method()) .unwrap_or(ExecutionMethod::OnnxOnly); results.push(ExecutionResultEntry { slice_id: slice_id.clone(), 
witness_execution: Some(ExecutionInfo { method, success: false, error: Some(e.to_string()), witness_file: None, tile_exec_infos: Vec::new(), }), proof_execution: None, verification_execution: None, }); break; } }; results.push(ExecutionResultEntry { slice_id: slice_id.clone(), witness_execution: Some(exec_info), proof_execution: None, verification_execution: None, }); current = node.next.clone(); } let mut final_meta = run_meta; final_meta.execution_chain.execution_results = results; final_meta.run_directory = Some(run_dir.to_string_lossy().into_owned()); let meta_out = run_dir.join(crate::utils::paths::METADATA_FILE); crate::utils::metadata::save_run_metadata(&meta_out, &final_meta)?; let last_slice = model_meta .slices .last() .ok_or_else(|| DsperseError::Pipeline("model has no slices".into()))?; let last_slice_id = format!("slice_{}", last_slice.index); if let Some(failed) = final_meta .execution_chain .execution_results .iter() .find(|r| r.witness_execution.as_ref().is_some_and(|w| !w.success)) { let err_msg = failed .witness_execution .as_ref() .and_then(|w| w.error.as_deref()) .unwrap_or("unknown"); return Err(DsperseError::Pipeline(format!( "pipeline failed at {}: {err_msg}", failed.slice_id ))); } let slice_run_meta = final_meta.slices.get(&last_slice_id); let last_strategy = match slice_run_meta { Some(m) => { let use_circuit = final_meta .execution_chain .nodes .get(&last_slice_id) .is_some_and(|n| n.use_circuit); ExecutionStrategy::from_metadata(m, use_circuit).ok() } None => None, }; let output_arrs: Vec<&ArrayD> = { let strategy_output = last_strategy .as_ref() .and_then(|s| s.output_name()) .and_then(|name| tensor_cache.try_get(name)); if let Some(arr) = strategy_output { vec![arr] } else if !model_meta.output_names.is_empty() { let found: Vec<_> = model_meta .output_names .iter() .filter_map(|n| tensor_cache.try_get(n)) .collect(); if found.is_empty() { tracing::warn!( expected = ?model_meta.output_names, available = 
?tensor_cache.keys().collect::>(),
                    "none of the declared output_names found in tensor cache"
                );
            }
            found
        } else {
            last_slice
                .dependencies
                .output
                .iter()
                .find_map(|n| tensor_cache.try_get(n))
                .into_iter()
                .collect()
        }
    };

    if output_arrs.is_empty() {
        let first_error = final_meta
            .execution_chain
            .execution_results
            .iter()
            .filter_map(|r| {
                r.witness_execution
                    .as_ref()
                    .and_then(|w| w.error.as_deref())
                    .map(|err| format!("{}: {err}", r.slice_id))
            })
            .next();
        return Err(match first_error {
            Some(err) => DsperseError::Pipeline(format!("pipeline failed at {err}")),
            None => DsperseError::Pipeline(format!(
                "no output tensor found for last slice {last_slice_id}"
            )),
        });
    }

    let output_path = run_dir.join(crate::utils::paths::OUTPUT_FILE);
    let output_val = Value::Array(output_arrs.iter().map(|arr| arrayd_to_value(arr)).collect());
    write_msgpack(
        &output_path,
        &build_msgpack_map(vec![("output_data", output_val)]),
    )?;

    Ok(final_meta)
}

/// Execute a model in "combined" mode: run the whole combined ONNX graph once,
/// then generate a witness for every circuit slice from the named tensors that
/// single run produced, instead of chaining per-slice ONNX executions.
///
/// * Optional consumer weights (`config.weights_onnx`) are validated against the
///   combined graph's initializers and patched into a copy of the combined model
///   (via `build_patched_onnx`) before inference. Circuit slices must then have
///   been compiled with `--weights-as-inputs`.
/// * Non-circuit slices are recorded as `OnnxOnly` successes without a separate
///   execution. Tiled circuit slices are delegated to `execute_combined_tiled`.
///   Channel-split and dim-split circuit slices are rejected.
/// * Run metadata is saved into `run_dir` (best-effort when a witness failed).
///   The selected output tensors are written to the run's output file.
///
/// Returns the completed run metadata, or an error naming the first slice whose
/// witness generation reported failure.
fn run_combined_inference(
    slices_dir: &Path,
    input_path: &Path,
    run_dir: &Path,
    backend: &JstproveBackend,
    config: &RunConfig,
    model_meta: &ModelMetadata,
) -> Result {
    let combined_path =
        crate::slicer::combiner::ensure_combined_materialized(slices_dir, model_meta)?;

    let donor_model = load_donor_model(config.weights_onnx.as_ref())?;
    let donor_map = donor_init_map(donor_model.as_ref())?;
    if let Some(ref map) = donor_map {
        // Consumer weights must line up with the combined graph's initializers
        // before they are patched in.
        let combined_model = crate::slicer::onnx_proto::load_model(&combined_path)?;
        let combined_graph = combined_model
            .graph
            .as_ref()
            .ok_or_else(|| DsperseError::Pipeline("combined ONNX missing graph".into()))?;
        crate::slicer::onnx_proto::validate_initializer_compatibility(
            &combined_graph.initializer,
            map,
            "combined",
        )?;
        // `donor_map` is only Some when `weights_onnx` was provided, so this unwrap holds.
        tracing::info!(
            weights = %config.weights_onnx.as_ref().unwrap().display(),
            "validated consumer weights against combined ONNX"
        );
    }

    std::fs::create_dir_all(run_dir).map_err(|e| DsperseError::io(e, run_dir))?;

    let input_data = read_msgpack(input_path)?;
    let input_val = extract_input_data(&input_data).ok_or_else(|| {
        DsperseError::Pipeline(
            "input has no recognized input key (input_data, input, data, inputs)".into(),
        )
    })?;

    let first_slice = model_meta
        .slices
        .first()
        .ok_or_else(|| DsperseError::Pipeline("model has no slices".into()))?;
    let declared_inputs = &first_slice.dependencies.filtered_inputs;
    if declared_inputs.is_empty() {
        return Err(DsperseError::Pipeline(
            "first slice has no input dependency".into(),
        ));
    }

    let input_copy = run_dir.join(crate::utils::paths::INPUT_FILE);
    write_msgpack(&input_copy, &input_data)?;

    // Keep the patched model handle alive for as long as `effective_path` is used.
    let effective_combined = if let Some(ref map) = donor_map {
        Some(crate::slicer::onnx_proto::build_patched_onnx(
            &combined_path,
            map,
        )?)
    } else {
        None
    };
    let effective_path = effective_combined
        .as_ref()
        .map_or(combined_path.as_path(), |t| t.path());

    let named_outputs = if input_val.is_map() {
        let mut cache = TensorStore::new();
        for name in declared_inputs {
            let v = map_get_ref(input_val, name)
                .ok_or_else(|| DsperseError::Pipeline(format!("input map missing key {name:?}")))?;
            cache.put(name.clone(), value_to_arrayd(v)?);
        }
        run_onnx_inference_multi_named(effective_path, &cache, declared_inputs)?
    } else if declared_inputs.len() == 1 {
        let input_arr = value_to_arrayd(input_val)?;
        run_onnx_inference_named(effective_path, &input_arr)?
    } else {
        return Err(DsperseError::Pipeline(format!(
            "model declares {} inputs but input is not a map",
            declared_inputs.len()
        )));
    };

    tracing::info!(
        outputs = named_outputs.len(),
        "combined model inference complete"
    );

    let mut tensor_cache = TensorStore::new();
    for (name, (data, shape)) in &named_outputs {
        let arr = ArrayD::from_shape_vec(IxDyn(shape), data.clone())
            .map_err(|e| DsperseError::Pipeline(format!("output reshape '{name}': {e}")))?;
        tensor_cache.put(name.clone(), arr);
    }

    // Model inputs are not necessarily among the combined outputs, but a first
    // circuit slice still needs them as activations: seed them from the request.
    for name in declared_inputs {
        if !tensor_cache.contains(name) {
            if input_val.is_map() {
                let v = map_get_ref(input_val, name).ok_or_else(|| {
                    DsperseError::Pipeline(format!(
                        "combined fallback: input map missing key {name:?}"
                    ))
                })?;
                tensor_cache.put(name.clone(), value_to_arrayd(v)?);
            } else if declared_inputs.len() == 1 {
                tensor_cache.put(name.clone(), value_to_arrayd(input_val)?);
            }
        }
    }

    crate::slicer::materializer::ensure_all_slices_materialized(slices_dir, model_meta)?;

    let chain = build_execution_chain(model_meta, slices_dir)?;
    let run_meta = build_run_metadata(model_meta, slices_dir, &chain)?;

    let mut results: Vec = Vec::new();

    for slice in &model_meta.slices {
        let slice_id = format!("slice_{}", slice.index);
        let node = chain
            .nodes
            .get(&slice_id)
            .ok_or_else(|| DsperseError::Pipeline(format!("missing node {slice_id}")))?;
        let slice_meta = run_meta.slices.get(&slice_id).ok_or_else(|| {
            DsperseError::Pipeline(format!("missing run slice metadata {slice_id}"))
        })?;

        let slice_run_dir = run_dir.join(&slice_id);
        std::fs::create_dir_all(&slice_run_dir).map_err(|e| DsperseError::io(e, &slice_run_dir))?;

        // Non-circuit slices were already covered by the combined ONNX run.
        if !node.use_circuit {
            results.push(ExecutionResultEntry {
                slice_id: slice_id.clone(),
                witness_execution: Some(ExecutionInfo {
                    method: ExecutionMethod::OnnxOnly,
                    success: true,
                    error: None,
                    witness_file: None,
                    tile_exec_infos: Vec::new(),
                }),
                proof_execution: None,
                verification_execution: None,
            });
            continue;
        }

        let strategy = ExecutionStrategy::from_metadata(slice_meta, node.use_circuit)?;

        if let ExecutionStrategy::ChannelSplit(_) = &strategy {
            return Err(DsperseError::Pipeline(format!(
                "{slice_id}: combined mode does not support channel-split circuit slices; use --combined false"
            )));
        }

        if let ExecutionStrategy::DimSplit(_) = &strategy {
            return Err(DsperseError::Pipeline(format!(
                "{slice_id}: combined mode does not support dim-split circuit slices; use --combined false"
            )));
        }

        if let ExecutionStrategy::Tiled(tiling) = &strategy {
            let result = super::tiled::execute_combined_tiled(
                slices_dir,
                &slice_run_dir,
                &slice_id,
                tiling,
                slice_meta.jstprove_circuit_path.as_deref(),
                &tensor_cache,
                backend,
                config,
                donor_map.as_ref(),
            )?;
            for (name, tensor) in result.outputs {
                tensor_cache.put(name, tensor);
            }
            let success = result.info.success;
            // `result.info` is stored as-is; its `error` may be None even when
            // `success` is false, which the failure check below must tolerate.
            results.push(ExecutionResultEntry {
                slice_id: slice_id.clone(),
                witness_execution: Some(result.info),
                proof_execution: None,
                verification_execution: None,
            });
            if !success {
                break;
            }
            continue;
        }

        let circuit_path = resolve_circuit_path_required(
            slices_dir,
            slice_meta.jstprove_circuit_path.as_deref(),
            &slice_id,
        )?;
        let params = backend.load_params(&circuit_path)?;
        let is_wai = params.as_ref().is_some_and(|p| p.weights_as_inputs);
        if donor_map.is_some() && !is_wai {
            return Err(DsperseError::Pipeline(format!(
                "{slice_id}: consumer weights require circuits compiled with --weights-as-inputs"
            )));
        }

        let activation_inputs: Vec = slice
            .dependencies
            .filtered_inputs
            .iter()
            .filter(|s| !s.is_empty())
            .cloned()
            .collect();

        let witness_result = if activation_inputs.is_empty() {
            Err(DsperseError::Pipeline(format!(
                "{slice_id}: no activation inputs declared for circuit slice"
            )))
        } else {
            // Concatenate every activation input, in declared order, into the
            // flat buffer the witness generator consumes.
            let mut flat_activations: Vec = Vec::new();
            for input_name in &activation_inputs {
                let input_arr = tensor_cache.get(input_name).map_err(|_| {
                    DsperseError::Pipeline(format!(
                        "{slice_id}: activation input '{input_name}' not found in combined model outputs"
                    ))
                })?;
                flat_activations.extend(input_arr.iter());
            }
            if is_wai {
                // `is_wai` implies `params` is Some, so the unwraps below hold.
                let onnx_path = slice.resolve_onnx(slices_dir)?;
                let initializers = if let Some(donor) = donor_map.as_ref() {
                    // Donor initializers override the slice's own by name.
                    let slice_model = crate::slicer::onnx_proto::load_model(&onnx_path)?;
                    let slice_graph = slice_model.graph.as_ref().ok_or_else(|| {
                        DsperseError::Pipeline(format!("{slice_id}: ONNX missing graph"))
                    })?;
                    let mut merged = crate::slicer::onnx_proto::build_initializer_map(slice_graph);
                    for (k, v) in donor.iter() {
                        merged.insert(k.clone(), *v);
                    }
                    extract_initializers_from_map(&merged, params.as_ref().unwrap())?
                } else {
                    extract_onnx_initializers(&onnx_path, params.as_ref().unwrap())?
                };
                backend.witness_f64(&circuit_path, &flat_activations, &initializers)
            } else {
                backend.witness_f64(&circuit_path, &flat_activations, &[])
            }
        };

        match witness_result {
            Ok(witness_bytes) => {
                let witness_path = slice_run_dir.join(crate::utils::paths::WITNESS_FILE);
                std::fs::write(&witness_path, &witness_bytes)
                    .map_err(|e| DsperseError::io(e, &witness_path))?;
                tracing::info!(slice = %slice_id, "witness generated from combined outputs");
                results.push(ExecutionResultEntry {
                    slice_id: slice_id.clone(),
                    witness_execution: Some(ExecutionInfo {
                        method: ExecutionMethod::JstproveGenWitness,
                        success: true,
                        error: None,
                        witness_file: Some(witness_path.to_string_lossy().into_owned()),
                        tile_exec_infos: Vec::new(),
                    }),
                    proof_execution: None,
                    verification_execution: None,
                });
            }
            Err(e) => {
                tracing::error!(slice = %slice_id, error = %e, "witness generation failed");
                results.push(ExecutionResultEntry {
                    slice_id: slice_id.clone(),
                    witness_execution: Some(ExecutionInfo {
                        method: ExecutionMethod::JstproveGenWitness,
                        success: false,
                        error: Some(e.to_string()),
                        witness_file: None,
                        tile_exec_infos: Vec::new(),
                    }),
                    proof_execution: None,
                    verification_execution: None,
                });
                break;
            }
        }
    }

    let mut final_meta = run_meta;
    final_meta.execution_chain.execution_results = results;
    final_meta.run_directory = Some(run_dir.to_string_lossy().into_owned());

    let meta_out = run_dir.join(crate::utils::paths::METADATA_FILE);

    // Any entry with `success == false` is a failure, whether or not it carries
    // an error message. Keying on `error` alone would let a failed slice without
    // a message (possible for the tiled branch) fall through and return Ok.
    // This mirrors the check in `run_inference`.
    let witness_failure = final_meta
        .execution_chain
        .execution_results
        .iter()
        .find_map(|r| {
            r.witness_execution
                .as_ref()
                .filter(|w| !w.success)
                .map(|w| {
                    format!(
                        "{}: {}",
                        r.slice_id,
                        w.error.as_deref().unwrap_or("unknown")
                    )
                })
        });

    if let Some(err) = witness_failure {
        // Best-effort: persist the partial run for diagnostics, but surface the
        // pipeline failure rather than a secondary save error.
        if let Err(save_err) = crate::utils::metadata::save_run_metadata(&meta_out, &final_meta) {
            tracing::warn!(
                error = %save_err,
                "failed to save run metadata after combined pipeline failure"
            );
        }
        return Err(DsperseError::Pipeline(format!(
            "combined pipeline failed at {err}"
        )));
    }

    crate::utils::metadata::save_run_metadata(&meta_out, &final_meta)?;

    let last_slice = model_meta
        .slices
        .last()
        .ok_or_else(|| DsperseError::Pipeline("model has no slices".into()))?;

    // Prefer the model's declared outputs; otherwise fall back to the first
    // available output of the last slice.
    let output_arrs: Vec<&ArrayD> = if !model_meta.output_names.is_empty() {
        model_meta
            .output_names
            .iter()
            .filter_map(|n| tensor_cache.try_get(n))
            .collect()
    } else {
        last_slice
            .dependencies
            .output
            .iter()
            .find_map(|n| tensor_cache.try_get(n))
            .into_iter()
            .collect()
    };

    if output_arrs.is_empty() {
        let expected: Vec<&str> = if !model_meta.output_names.is_empty() {
            model_meta.output_names.iter().map(String::as_str).collect()
        } else {
            last_slice
                .dependencies
                .output
                .iter()
                .map(String::as_str)
                .collect()
        };
        let available: Vec<&String> = tensor_cache.keys().collect();
        return Err(DsperseError::Pipeline(format!(
            "no output tensor found in combined model outputs; expected {expected:?}, available {available:?}"
        )));
    }

    let output_path = run_dir.join(crate::utils::paths::OUTPUT_FILE);
    let output_val = Value::Array(output_arrs.iter().map(|arr| arrayd_to_value(arr)).collect());
    write_msgpack(
        &output_path,
        &build_msgpack_map(vec![("output_data", output_val)]),
    )?;

    tracing::info!(
        run_dir = %run_dir.display(),
        slices = model_meta.slices.len(),
        "combined inference complete"
    );

    Ok(final_meta)
}

#[allow(clippy::too_many_arguments)]
fn execute_slice(
    slices_dir: &Path,
    slice_run_dir: &Path,
    slice_id: &str,
    node: &ExecutionNode,
    meta: &RunSliceMetadata,
    tensor_cache: &mut TensorStore,
    backend: &JstproveBackend,
    config:
&RunConfig, donor_init_map: Option<&HashMap>, ) -> Result { let strategy = ExecutionStrategy::from_metadata(meta, node.use_circuit)?; match strategy { ExecutionStrategy::ChannelSplit(cs) => { let target_shape = meta .dependencies .output .iter() .position(|name| name == &cs.output_name) .and_then(|idx| meta.output_shape.get(idx)) .map(|v| v.as_slice()); if target_shape.is_none() { tracing::debug!( slice = %slice_id, output_name = %cs.output_name, "target_shape lookup failed; output will not be reshaped" ); } let result = super::channel_split::execute_channel_split( slices_dir, slice_run_dir, slice_id, cs, target_shape, tensor_cache, backend, donor_init_map, )?; for (name, tensor) in result.outputs { tensor_cache.put(name, tensor); } Ok(result.info) } ExecutionStrategy::Tiled(tiling) => { let slice_circuit = resolve_circuit_path_optional(slices_dir, meta.jstprove_circuit_path.as_deref())?; let result = super::tiled::execute_tiled( slices_dir, slice_run_dir, slice_id, tiling, slice_circuit.as_deref(), tensor_cache, backend, config, donor_init_map, )?; for (name, tensor) in result.outputs { tensor_cache.put(name, tensor); } Ok(result.info) } ExecutionStrategy::DimSplit(ds) => { let target_shape = meta .dependencies .output .iter() .position(|name| name == &ds.output_name) .and_then(|idx| meta.output_shape.get(idx)) .map(|v| v.as_slice()); let result = super::dim_split::execute_dim_split( slices_dir, slice_run_dir, slice_id, ds, target_shape, tensor_cache, backend, donor_init_map, )?; for (name, tensor) in result.outputs { tensor_cache.put(name, tensor); } Ok(result.info) } ExecutionStrategy::Single { .. 
} => { let result = execute_single( slices_dir, slice_run_dir, slice_id, node, meta, tensor_cache, backend, donor_init_map, )?; for (name, tensor) in result.outputs { tensor_cache.put(name, tensor); } Ok(result.info) } } }

/// Execute a slice that carries no split metadata as a single unit.
///
/// The slice's non-empty `filtered_inputs` are read from `tensor_cache` and run
/// through the slice ONNX model (a patched copy when `donor_init_map` is
/// supplied). The declared outputs are returned in `StrategyOutput::outputs`
/// for the caller to store. When `node.use_circuit` is set, a JSTProve witness
/// is also generated and written to `slice_run_dir/WITNESS_FILE`.
///
/// # Errors
/// - no activation inputs are declared;
/// - donor weights are supplied but the circuit was not compiled with
///   weights-as-inputs;
/// - circuit resolution, inference, output collection, witness generation, or
///   writing the witness file fails.
#[allow(clippy::too_many_arguments)]
fn execute_single(
    slices_dir: &Path,
    slice_run_dir: &Path,
    slice_id: &str,
    node: &ExecutionNode,
    meta: &RunSliceMetadata,
    tensor_cache: &TensorStore,
    backend: &JstproveBackend,
    donor_init_map: Option<&HashMap>,
) -> Result {
    let inputs: Vec = meta
        .dependencies
        .filtered_inputs
        .iter()
        .filter(|s| !s.is_empty())
        .cloned()
        .collect();
    if inputs.is_empty() {
        return Err(DsperseError::Pipeline(format!(
            "{slice_id}: no activation inputs declared"
        )));
    }

    let onnx_path = PathBuf::from(&meta.path);
    // With donor weights, inference runs on a patched copy of the model. The
    // handle is bound for the rest of the function because `effective_onnx`
    // borrows from it.
    let patched_onnx = if let Some(map) = donor_init_map {
        Some(crate::slicer::onnx_proto::build_patched_onnx(
            &onnx_path, map,
        )?)
    } else {
        None
    };
    let effective_onnx: &Path = patched_onnx
        .as_ref()
        .map_or(onnx_path.as_path(), |t| t.path());

    // Inference plus output collection, shared by the circuit and ONNX-only
    // paths. A single input is gathered into one tensor; several inputs are
    // passed to the backend by name.
    let infer_outputs = || {
        let named = if inputs.len() > 1 {
            run_onnx_inference_multi_named(effective_onnx, tensor_cache, &inputs)?
        } else {
            let input_tensor = tensor_cache.gather(&inputs)?;
            run_onnx_inference_named(effective_onnx, &input_tensor)?
        };
        collect_named_outputs(&meta.dependencies.output, named)
    };

    if node.use_circuit {
        let circuit_path = resolve_circuit_path_required(
            slices_dir,
            meta.jstprove_circuit_path.as_deref(),
            slice_id,
        )?;
        let params = backend.load_params(&circuit_path)?;
        // `Some` only for circuits compiled with weights as inputs. Those take
        // the initializers as part of the witness (see `generate_wai_witness`).
        let wai_params = params.as_ref().filter(|p| p.weights_as_inputs);
        if donor_init_map.is_some() && wai_params.is_none() {
            return Err(DsperseError::Pipeline(format!(
                "{slice_id}: consumer weights require circuits compiled with --weights-as-inputs"
            )));
        }
        let outputs = infer_outputs()?;
        let flat_activations = flatten_cached_inputs(tensor_cache, &inputs)?;
        let witness_bytes = match wai_params {
            Some(wai) => generate_wai_witness(
                backend,
                &circuit_path,
                &onnx_path,
                donor_init_map,
                wai,
                &flat_activations,
            )?,
            None => backend.witness_f64(&circuit_path, &flat_activations, &[])?,
        };
        let witness_path = slice_run_dir.join(crate::utils::paths::WITNESS_FILE);
        std::fs::write(&witness_path, &witness_bytes)
            .map_err(|e| DsperseError::io(e, &witness_path))?;
        Ok(crate::schema::execution::StrategyOutput {
            info: ExecutionInfo {
                method: ExecutionMethod::JstproveGenWitness,
                success: true,
                error: None,
                witness_file: Some(witness_path.to_string_lossy().into_owned()),
                tile_exec_infos: Vec::new(),
            },
            outputs,
        })
    } else {
        let outputs = infer_outputs()?;
        Ok(crate::schema::execution::StrategyOutput {
            info: ExecutionInfo {
                method: ExecutionMethod::OnnxOnly,
                success: true,
                error: None,
                witness_file: None,
                tile_exec_infos: Vec::new(),
            },
            outputs,
        })
    }
}

#[cfg(test)] fn store_named_outputs( tensor_cache: &mut TensorStore, output_names: &[String], named_outputs: HashMap, Vec)>, ) -> Result<()> { for (name, tensor) in collect_named_outputs(output_names, named_outputs)?
{ tensor_cache.put(name, tensor); } Ok(()) }

/// Turn raw named inference outputs into arrays, in declared order.
///
/// Each name in `output_names` must be declared only once and must be present
/// in `named_outputs`. Its flat data is reshaped to the reported shape. The
/// first violation (duplicate name, missing output, or reshape failure) is
/// returned as an error, and no partial result is produced.
fn collect_named_outputs(
    output_names: &[String],
    mut named_outputs: HashMap, Vec)>,
) -> Result)>> {
    let mut claimed = std::collections::HashSet::new();
    output_names
        .iter()
        .map(|name| {
            if !claimed.insert(name) {
                return Err(DsperseError::Pipeline(format!(
                    "duplicate declared output '{name}'"
                )));
            }
            let Some((data, shape)) = named_outputs.remove(name) else {
                return Err(DsperseError::Pipeline(format!(
                    "missing declared output '{name}'"
                )));
            };
            ArrayD::from_shape_vec(IxDyn(&shape), data)
                .map(|arr| (name.clone(), arr))
                .map_err(|e| DsperseError::Pipeline(format!("output reshape '{name}': {e}")))
        })
        .collect()
}

/// Run single-input ONNX inference and reshape the returned flat output into
/// an array.
pub(crate) fn run_onnx_inference(onnx_path: &Path, input: &ArrayD) -> Result> {
    let flattened: Vec = input.iter().copied().collect();
    let (values, dims) =
        crate::backend::onnx::run_inference(onnx_path, &flattened, input.shape())?;
    ArrayD::from_shape_vec(IxDyn(&dims), values)
        .map_err(|e| DsperseError::Pipeline(format!("output reshape: {e}")))
}

/// Run single-input ONNX inference and return the backend's named outputs.
pub(crate) fn run_onnx_inference_named(
    onnx_path: &Path,
    input: &ArrayD,
) -> Result {
    let flattened: Vec = input.iter().copied().collect();
    crate::backend::onnx::run_inference_named(onnx_path, &flattened, input.shape())
}

/// Run ONNX inference, feeding each cached tensor listed in `input_names`
/// under its own name.
///
/// Errors if any listed tensor is missing from `tensor_cache`.
pub(crate) fn run_onnx_inference_multi_named(
    onnx_path: &Path,
    tensor_cache: &TensorStore,
    input_names: &[String],
) -> Result {
    let inputs: Vec<(&str, Vec, Vec)> = input_names
        .iter()
        .map(|name| {
            tensor_cache.get(name).map(|arr| {
                (
                    name.as_str(),
                    arr.iter().copied().collect(),
                    arr.shape().to_vec(),
                )
            })
        })
        .collect::>>()?;
    crate::backend::onnx::run_inference_multi_named(onnx_path, &inputs)
}

/// Build the linear execution chain over `model_meta.slices`, in order.
///
/// A slice counts as circuit-backed when `jstprove/circuit.bundle` exists as
/// a directory inside its slice directory. Such nodes use the JSTProve backend
/// with an `onnx` fallback and record the bundle path relative to
/// `slices_dir`. All other nodes run ONNX only. Each node links to the
/// following slice through `next`.
pub(crate) fn build_execution_chain(
    model_meta: &ModelMetadata,
    slices_dir: &Path,
) -> Result {
    let head = model_meta
        .slices
        .first()
        .map(|s| format!("slice_{}", s.index));
    let mut nodes = HashMap::new();
    for (pos, slice) in model_meta.slices.iter().enumerate() {
        let slice_id = format!("slice_{}", slice.index);
        let has_circuit = slice_dir_path(slices_dir, slice.index)
            .join("jstprove/circuit.bundle")
            .is_dir();
        let circuit_path =
            has_circuit.then(|| format!("slice_{}/jstprove/circuit.bundle", slice.index));
        let next = model_meta
            .slices
            .get(pos + 1)
            .map(|s| format!("slice_{}", s.index));
        let onnx_path = Some(
            slice
                .resolve_onnx(slices_dir)?
                .to_string_lossy()
                .into_owned(),
        );
        let (backend, fallbacks) = if has_circuit {
            (BackendKind::Jstprove, vec!["onnx".into()])
        } else {
            (BackendKind::Onnx, Vec::new())
        };
        let primary = Some(backend.to_string());
        nodes.insert(
            slice_id.clone(),
            ExecutionNode {
                slice_id,
                primary,
                fallbacks,
                use_circuit: has_circuit,
                next,
                circuit_path,
                onnx_path,
                backend,
            },
        );
    }
    Ok(ExecutionChain {
        head,
        nodes,
        fallback_map: HashMap::new(),
        execution_results: Vec::new(),
        jstprove_proved_slices: 0,
        jstprove_verified_slices: 0,
    })
}

/// Build per-slice run metadata from `model_meta`, taking each slice's backend
/// and circuit path from its node in `chain`.
pub(crate) fn build_run_metadata( model_meta: &ModelMetadata, slices_dir: &Path, chain: &ExecutionChain, ) -> Result { let mut slices = HashMap::new(); for slice in &model_meta.slices { let slice_id = format!("slice_{}", slice.index); let node = chain.nodes.get(&slice_id); let has_circuit = node.is_some_and(|n| n.use_circuit); let run_slice = RunSliceMetadata { path: slice .resolve_onnx(slices_dir)?
.to_string_lossy() .into_owned(), input_shape: slice.shape.tensor_shape.input.clone(), output_shape: slice.shape.tensor_shape.output.clone(), dependencies: slice.dependencies.clone(), tiling: slice.tiling.clone(), channel_split: slice.channel_split.clone(), dim_split: slice.dim_split.clone(), backend: if has_circuit { BackendKind::Jstprove } else { BackendKind::Onnx }, jstprove_circuit_path: node.and_then(|n| n.circuit_path.clone()), jstprove_settings_path: None, }; slices.insert(slice_id, run_slice); } Ok(RunMetadata { slices, execution_chain: chain.clone(), packaging_type: None, source_path: Some(slices_dir.to_string_lossy().into_owned()), run_directory: None, model_path: None, }) }

/// Collect `(values, shape)` pairs for the circuit inputs present in `init_map`.
///
/// Iterates `params.inputs` in order and skips inputs with no matching
/// initializer. Values are converted to `f64`. Padding applies when a tensor
/// has fewer elements than the circuit input's shape requires and its
/// trailing dimension is narrower than the target's. In that case every row
/// along the trailing dimension is right-padded to the target width, with
/// `-10.0` for rank-1 (bias) tensors and `0.0` otherwise.
///
/// The returned shape is the circuit input shape when the element count
/// matches it. Otherwise it is the shape the returned values actually have.
pub(crate) fn extract_initializers_from_map(
    init_map: &HashMap,
    params: &CircuitParams,
) -> Result, Vec)>> {
    let mut initializers = Vec::new();
    for io in &params.inputs {
        if let Some(tensor) = init_map.get(&io.name) {
            let f32_vals = crate::slicer::onnx_proto::tensor_to_f32(tensor);
            let mut f64_vals: Vec = f32_vals.iter().map(|&v| f64::from(v)).collect();
            let target_shape = &io.shape;
            // Shape describing `f64_vals`; its trailing dim is widened below if
            // the rows get padded.
            let mut tensor_shape: Vec = tensor.dims.iter().map(|&d| d as usize).collect();
            let target_elems: usize = target_shape.iter().product();
            if f64_vals.len() < target_elems && !target_shape.is_empty() && !tensor_shape.is_empty() {
                let is_bias = tensor_shape.len() == 1;
                let pad_val: f64 = if is_bias { -10.0 } else { 0.0 };
                // Take each shape's own trailing dimension: donor and target may
                // differ in rank, so indexing the donor with the target's last
                // index can go out of bounds.
                let target_last = target_shape[target_shape.len() - 1];
                let donor_last_idx = tensor_shape.len() - 1;
                let donor_last = tensor_shape[donor_last_idx];
                if donor_last < target_last {
                    let rows = f64_vals.len() / donor_last.max(1);
                    let mut padded = Vec::with_capacity(target_elems);
                    for row in 0..rows {
                        let start = row * donor_last;
                        let end = start + donor_last;
                        padded.extend_from_slice(&f64_vals[start..end.min(f64_vals.len())]);
                        padded.resize(padded.len() + (target_last - donor_last), pad_val);
                    }
                    f64_vals = padded;
                    // Keep the fallback shape consistent with the widened rows.
                    tensor_shape[donor_last_idx] = target_last;
                }
            }
            let shape: Vec = if f64_vals.len() == target_elems {
                target_shape.clone()
            } else {
                tensor_shape
            };
            initializers.push((f64_vals, shape));
        }
    }
    Ok(initializers)
}

/// Load `onnx_path` and extract the initializers matching `params.inputs`.
///
/// Errors if the model cannot be loaded or has no graph.
pub fn extract_onnx_initializers( onnx_path: &Path, params: &CircuitParams, ) -> Result, Vec)>> { let model = crate::slicer::onnx_proto::load_model(onnx_path)?; let graph = model .graph .as_ref() .ok_or_else(|| DsperseError::Pipeline("ONNX model missing graph".into()))?; let init_map = crate::slicer::onnx_proto::build_initializer_map(graph); extract_initializers_from_map(&init_map, params) }

/// Concatenate the cached tensors named in `names`, in that order, into one
/// flat vector. Errors if any name is missing from `cache`.
pub(crate) fn flatten_cached_inputs(cache: &TensorStore, names: &[String]) -> Result> { let arrays: Vec<&ArrayD> = names.iter().map(|n| cache.get(n)).collect::>()?; let total: usize = arrays.iter().map(|a| a.len()).sum(); let mut flat = Vec::with_capacity(total); for arr in arrays { flat.extend(arr.iter()); } Ok(flat) }

/// Generate a witness whose initializers come from the slice ONNX model. When
/// `donor_init_map` is supplied, its entries override same-named slice
/// initializers before extraction.
pub(crate) fn generate_wai_witness( backend: &JstproveBackend, circuit_path: &Path, slice_onnx_path: &Path, donor_init_map: Option<&HashMap>, params: &CircuitParams, flat_activations: &[f64], ) -> Result> { let initializers = if let Some(donor) = donor_init_map { let slice_model = crate::slicer::onnx_proto::load_model(slice_onnx_path)?; let slice_graph = slice_model .graph .as_ref() .ok_or_else(|| DsperseError::Pipeline("slice ONNX missing graph".into()))?; let mut merged = crate::slicer::onnx_proto::build_initializer_map(slice_graph); for (k, v) in donor.iter() { merged.insert(k.clone(), *v); } extract_initializers_from_map(&merged, params)? } else { extract_onnx_initializers(slice_onnx_path, params)?
}; backend.witness_f64(circuit_path, flat_activations, &initializers) }
// Unit tests for the tile reshape/split/reconstruct helpers from `tiled`, and
// for `store_named_outputs` and `RunConfig` in this module.
#[cfg(test)] mod tests { use super::super::tiled::{reconstruct_from_tiles, reshape_to_4d, split_into_tiles}; use super::*; use crate::schema::tiling::TilingInfo; use ndarray::Array4;
    // Build a 4-D TilingInfo with one input channel, unit stride, no per-tile
    // entries, and `h`/`w` equal to `tiles_y * tile_size` / `tiles_x * tile_size`.
    fn make_tiling( tile_size: usize, tiles_y: usize, tiles_x: usize, halo: [i64; 4], out_tile: [i64; 2], c_out: usize, ) -> TilingInfo { TilingInfo { slice_idx: 0, tile_size, num_tiles: tiles_y * tiles_x, tiles_y, tiles_x, halo, out_tile, stride: [1, 1], c_in: 1, c_out, input_name: "input".into(), output_name: "output".into(), input_names: vec![], ndim: 4, h: tiles_y * tile_size, w: tiles_x * tile_size, tile: None, tiles: None, segment_size: None, total_elements: None, original_shape: vec![], } }
    // 24 values reshape into a (1, 2, 3, 4) array.
    #[test] fn reshape_to_4d_valid() { let data: Vec = (0..24).map(|i| i as f64).collect(); let arr = reshape_to_4d(&data, 2, 3, 4).unwrap(); assert_eq!(arr.dim(), (1, 2, 3, 4)); }
    // A single value reshapes into (1, 1, 1, 1) and keeps its value.
    #[test] fn reshape_to_4d_single_element() { let data = vec![42.0]; let arr = reshape_to_4d(&data, 1, 1, 1).unwrap(); assert_eq!(arr.dim(), (1, 1, 1, 1)); assert_eq!(arr[[0, 0, 0, 0]], 42.0); }
    // An element count that does not match c*h*w is rejected.
    #[test] fn reshape_to_4d_mismatch() { let data = vec![1.0; 10]; assert!(reshape_to_4d(&data, 2, 3, 4).is_err()); }
    // Empty input is rejected even for a 1x1x1 target.
    #[test] fn reshape_to_4d_empty() { let data: Vec = vec![]; assert!(reshape_to_4d(&data, 1, 1, 1).is_err()); }
    // A 4x4 input in a 2x2 grid of size-2 tiles with no halo gives four 2x2 tiles.
    #[test] fn split_into_tiles_2x2_no_halo() { let input = Array4::from_shape_vec((1, 1, 4, 4), (0..16).map(|i| i as f64).collect()).unwrap(); let tiling = make_tiling(2, 2, 2, [0, 0, 0, 0], [2, 2], 1); let tiles = split_into_tiles(&input, &tiling).unwrap(); assert_eq!(tiles.len(), 4); for tile in &tiles { assert_eq!(tile.dim(), (1, 1, 2, 2)); } }
    // With halo [1, 1, 1, 1] each of the four tiles is 4x4 instead of 2x2.
    #[test] fn split_into_tiles_with_halo() { let input = Array4::from_shape_vec((1, 1, 4, 4), (0..16).map(|i| i as f64).collect()).unwrap(); let tiling = make_tiling(2, 2, 2, [1, 1, 1, 1], [2, 2], 1); let tiles = split_into_tiles(&input, &tiling).unwrap(); assert_eq!(tiles.len(), 4); for tile in &tiles { assert_eq!(tile.dim(), (1, 1, 4, 4)); } }
    // A negative halo component is rejected.
    #[test] fn split_into_tiles_negative_halo_rejected() { let input = Array4::zeros((1, 1, 4, 4)); let tiling = make_tiling(2, 2, 2, [-1, 0, 0, 0], [2, 2], 1); assert!(split_into_tiles(&input, &tiling).is_err()); }
    // A batch dimension greater than 1 is rejected.
    #[test] fn split_into_tiles_batch_gt1_rejected() { let input = Array4::zeros((2, 1, 4, 4)); let tiling = make_tiling(2, 1, 1, [0, 0, 0, 0], [2, 2], 1); assert!(split_into_tiles(&input, &tiling).is_err()); }
    // Four 2x2 output tiles in a 2x2 grid reconstruct a (1, c_out, 4, 4) output.
    #[test] fn reconstruct_from_tiles_2x2() { let c_out = 1; let out_h = 2usize; let out_w = 2usize; let tiling = make_tiling(4, 2, 2, [0, 0, 0, 0], [out_h as i64, out_w as i64], c_out); let tiles: Vec> = (0..4) .map(|i| { ArrayD::from_shape_vec( IxDyn(&[1, c_out, out_h, out_w]), vec![i as f64; c_out * out_h * out_w], ) .unwrap() }) .collect(); let output = reconstruct_from_tiles(&tiles, &tiling).unwrap(); assert_eq!(output.shape(), &[1, c_out, 4, 4]); }
    // Reconstructing from no tiles is an error.
    #[test] fn reconstruct_from_tiles_empty() { let tiling = make_tiling(2, 1, 1, [0, 0, 0, 0], [2, 2], 1); assert!(reconstruct_from_tiles(&[], &tiling).is_err()); }
    // A tile whose element count does not match the expected output tile is rejected.
    #[test] fn reconstruct_from_tiles_wrong_element_count() { let tiling = make_tiling(2, 1, 1, [0, 0, 0, 0], [2, 2], 1); let bad_tile = vec![ArrayD::from_shape_vec(IxDyn(&[3]), vec![1.0; 3]).unwrap()]; assert!(reconstruct_from_tiles(&bad_tile, &tiling).is_err()); }
    // Both too few (3) and too many (5) tiles for a 2x2 grid are rejected.
    #[test] fn reconstruct_from_tiles_wrong_tile_count() { let c_out = 1; let out_h = 2i64; let out_w = 2i64; let tiling = make_tiling(4, 2, 2, [0, 0, 0, 0], [out_h, out_w], c_out); let make_tile = || { ArrayD::from_shape_vec( IxDyn(&[1, c_out, out_h as usize, out_w as usize]), vec![0.0f64; c_out * out_h as usize * out_w as usize], ) .unwrap() }; let too_few: Vec> = (0..3).map(|_| make_tile()).collect(); assert!(reconstruct_from_tiles(&too_few, &tiling).is_err()); let too_many: Vec> = (0..5).map(|_| make_tile()).collect(); assert!(reconstruct_from_tiles(&too_many, &tiling).is_err()); }
    // Splitting with no halo and reconstructing the raw tiles returns the original input.
    #[test] fn split_reconstruct_roundtrip() { let c = 2; let h = 8; let w = 8; let data: Vec = (0..(c * h * w)).map(|i| i as f64).collect(); let input = Array4::from_shape_vec((1, c, h, w), data).unwrap(); let tile_size = 4; let tiling = make_tiling(tile_size, 2, 2, [0, 0, 0, 0], [4, 4], c); let tiles = split_into_tiles(&input, &tiling).unwrap(); assert_eq!(tiles.len(), 4); let tile_outputs: Vec> = tiles.into_iter().map(|t| t.into_dyn()).collect(); let reconstructed = reconstruct_from_tiles(&tile_outputs, &tiling).unwrap(); assert_eq!(reconstructed.shape(), &[1, c, h, w]); let input_dyn = input.into_dyn(); assert_eq!(input_dyn, reconstructed); }
    // Declared outputs are stored under their names with their reported shapes.
    #[test] fn store_named_outputs_basic() { let mut cache = TensorStore::new(); let names = vec!["out_a".to_string(), "out_b".to_string()]; let mut named = HashMap::new(); named.insert("out_a".to_string(), (vec![1.0, 2.0], vec![2])); named.insert("out_b".to_string(), (vec![3.0], vec![1])); store_named_outputs(&mut cache, &names, named).unwrap(); assert_eq!(cache.get("out_a").unwrap().shape(), &[2]); assert_eq!(cache.get("out_b").unwrap().shape(), &[1]); }
    // A declared output absent from the inference results is an error.
    #[test] fn store_named_outputs_missing_name_errors() { let mut cache = TensorStore::new(); let names = vec!["missing".to_string()]; let named = HashMap::new(); let result = store_named_outputs(&mut cache, &names, named); assert!(result.is_err()); }
    // On error nothing is written: "present" is not stored, and existing entries survive.
    #[test] fn store_named_outputs_partial_write_errors() { let mut cache = TensorStore::new(); cache.put( "pre_existing".into(), ArrayD::from_shape_vec(ndarray::IxDyn(&[1]), vec![99.0]).unwrap(), ); let names = vec!["present".to_string(), "missing".to_string()]; let mut named = HashMap::new(); named.insert("present".to_string(), (vec![1.0, 2.0], vec![2])); let result = store_named_outputs(&mut cache, &names, named); assert!(result.is_err()); assert!(cache.contains("pre_existing")); assert!(!cache.contains("present")); }
    // Pins the RunConfig defaults.
    #[test] fn run_config_default() { let config = RunConfig::default(); assert_eq!(config.parallel, 1); assert!(!config.batch); assert!(config.weights_onnx.is_none());
assert!(config.combined); }
    // `flatten_cached_inputs` concatenates tensors in the order the names are
    // given, each flattened in its own iteration order.
    #[test]
    fn multi_input_activation_concatenation_ordering() {
        use ndarray::IxDyn;
        let mut cache = TensorStore::new();
        cache.put(
            "act_a".into(),
            ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
        );
        cache.put(
            "act_b".into(),
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![7.0, 8.0]).unwrap(),
        );
        cache.put(
            "act_c".into(),
            ArrayD::from_shape_vec(IxDyn(&[1]), vec![9.0]).unwrap(),
        );
        let inputs = vec![
            "act_a".to_string(),
            "act_b".to_string(),
            "act_c".to_string(),
        ];
        // Exercise the production helper instead of re-implementing its loop.
        let flat = flatten_cached_inputs(&cache, &inputs).unwrap();
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    // A missing tensor makes `flatten_cached_inputs` fail with an error that
    // names the missing tensor.
    #[test]
    fn multi_input_activation_missing_tensor_error() {
        let mut cache = TensorStore::new();
        cache.put(
            "act_a".into(),
            ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[2]), vec![1.0, 2.0]).unwrap(),
        );
        let inputs = vec!["act_a".to_string(), "act_missing".to_string()];
        let err = flatten_cached_inputs(&cache, &inputs).unwrap_err();
        assert!(err.to_string().contains("act_missing"));
    }
}
================================================ FILE: crates/dsperse/src/pipeline/slice_cache.rs ================================================
use std::io::Read; use std::path::Path; use crate::error::{DsperseError, Result};

/// Upper bound on the buffer preallocated for one zip entry. The size recorded
/// in an archive header is not trusted, so it only serves as a capped capacity
/// hint; `read_to_end` still grows the buffer to the real entry size.
const MAX_ENTRY_PREALLOC: u64 = 64 * 1024 * 1024;

/// Raw assets read from a slice's `.dslice` archive, when one exists.
pub struct SliceAssets { pub circuit_bytes: Option>, pub onnx_bytes: Option>, }

impl SliceAssets {
    /// Load the circuit and ONNX payloads from `slices_dir/{slice_id}.dslice`.
    ///
    /// Returns both fields as `None` when the archive does not exist. The last
    /// entry whose name ends in `circuit.bin` supplies `circuit_bytes`, and the
    /// last `.onnx` entry under `payload/` supplies `onnx_bytes`.
    ///
    /// NOTE(review): entry contents are still read fully into memory with no
    /// size limit, so a highly compressed archive can use a lot of memory.
    pub fn load_from_dslice(slices_dir: &Path, slice_id: &str) -> Result {
        let archive_path = slices_dir.join(format!("{slice_id}.dslice"));
        if !archive_path.exists() {
            return Ok(Self {
                circuit_bytes: None,
                onnx_bytes: None,
            });
        }
        let file =
            std::fs::File::open(&archive_path).map_err(|e| DsperseError::io(e, &archive_path))?;
        let mut zip = zip::ZipArchive::new(file).map_err(|e| {
            DsperseError::Slicer(format!(
                "reading dslice archive {}: {e}",
                archive_path.display()
            ))
        })?;
        let mut circuit_bytes = None;
        let mut onnx_bytes = None;
        for i in 0..zip.len() {
            let mut entry = zip
                .by_index(i)
                .map_err(|e| DsperseError::Slicer(format!("reading zip entry {i}: {e}")))?;
            let name = entry.name().to_string();
            // Cap the declared size before using it as a capacity: a forged
            // header could otherwise request an enormous allocation up front.
            let capacity = entry.size().min(MAX_ENTRY_PREALLOC) as usize;
            if name.ends_with("circuit.bin") {
                let mut buf = Vec::with_capacity(capacity);
                entry.read_to_end(&mut buf).map_err(|e| {
                    DsperseError::Slicer(format!("reading circuit.bin from dslice: {e}"))
                })?;
                circuit_bytes = Some(buf);
            } else if name.ends_with(".onnx") && name.starts_with("payload/") {
                let mut buf = Vec::with_capacity(capacity);
                entry
                    .read_to_end(&mut buf)
                    .map_err(|e| DsperseError::Slicer(format!("reading onnx from dslice: {e}")))?;
                onnx_bytes = Some(buf);
            }
        }
        Ok(Self {
            circuit_bytes,
            onnx_bytes,
        })
    }
}
================================================ FILE: crates/dsperse/src/pipeline/stage.rs ================================================
use std::path::Path; use rayon::prelude::*; use crate::backend::ProofBackend; use crate::error::{DsperseError, Result}; use crate::schema::execution::{ExecutionMethod, RunMetadata, SliceResult, TileResult}; use crate::schema::metadata::RunSliceMetadata; use crate::schema::tiling::TilingInfo; use crate::utils::paths::resolve_relative_path; use super::tile_executor::resolve_tile_circuit;
/// Which JSTProve operation a pipeline pass performs over the circuit slices.
#[derive(Debug, Clone, Copy)] pub enum PipelineStage { Prove, Verify, }
/// Per-stage execution method and log/error wording.
impl PipelineStage { fn execution_method(&self) -> ExecutionMethod { match self { Self::Prove => ExecutionMethod::JstproveProve, Self::Verify => ExecutionMethod::JstproveVerify, } } fn action_label(&self) -> &'static str { match self { Self::Prove => "proving", Self::Verify => "verifying", } } fn past_label(&self) -> &'static str { match self { Self::Prove => "proved", Self::Verify => "verified", } } fn error_label(&self) -> &'static str { match self { Self::Prove => "proof", Self::Verify => "verification", } } }
/// Run `stage` over every circuit slice listed in the run metadata under
/// `run_dir`, on a rayon pool built with `parallel` threads. Per-slice results
/// and the success count are recorded back into that metadata, which is
/// rewritten and returned.
pub fn run_pipeline_stage( stage: PipelineStage, run_dir:
&Path, slices_dir: &Path, backend: &dyn ProofBackend, parallel: usize, ) -> Result {
    // Load the run metadata stored in `run_dir` (read via `read_checked`,
    // decoded with rmp_serde).
    let meta_path = run_dir.join(crate::utils::paths::METADATA_FILE); let data = crate::utils::limits::read_checked(&meta_path)?; let mut run_meta: RunMetadata = rmp_serde::from_slice(&data)?;
    // Snapshot (id, metadata) for each circuit slice, so the parallel work does
    // not borrow `run_meta`, which is mutated afterwards.
    let circuit_slices: Vec<(String, _)> = run_meta .iter_circuit_slices() .map(|(id, meta)| (id.to_string(), meta.clone())) .collect();
    tracing::info!( total = circuit_slices.len(), "{} circuit slices", stage.action_label() );
    // `parallel` is passed as the thread count of a dedicated rayon pool.
    let pool = rayon::ThreadPoolBuilder::new() .num_threads(parallel) .build() .map_err(|e| DsperseError::Pipeline(format!("thread pool: {e}")))?;
    // Process slices concurrently. Ids that are not `slice_` followed by a
    // parseable index are reported as errors without touching the backend.
    // Every outcome is logged.
    let results: Vec<_> = pool.install(|| { circuit_slices .par_iter() .map(|(slice_id, meta)| { if slice_id.strip_prefix("slice_").and_then(|s| s.parse::().ok()).is_none() { return ( slice_id.clone(), Err(DsperseError::Pipeline(format!( "invalid slice_id format: {slice_id:?}" ))), ); } let slice_run_dir = run_dir.join(slice_id); let result = execute_single_slice(stage, slices_dir, &slice_run_dir, slice_id, meta, backend); match &result { Ok(r) if r.success => tracing::info!(slice = %slice_id, "{}", stage.past_label()), Ok(r) => tracing::error!( slice = %slice_id, error = r.error.as_deref().unwrap_or("unknown"), "{} failed", stage.error_label() ), Err(e) => tracing::error!(slice = %slice_id, error = %e, "{} error", stage.error_label()), } (slice_id.clone(), result) }) .collect() });
    // Fold results back into the metadata. Hard errors become failed
    // SliceResults with a 0.0 duration. Results for slices without an
    // `execution_results` entry are logged and dropped.
    let method = stage.execution_method(); let mut succeeded = 0; for (slice_id, result) in results { let slice_result = match result { Ok(r) => { if r.success { succeeded += 1; } r } Err(e) => SliceResult::failure(slice_id.clone(), method, e.to_string(), 0.0), }; if let Some(entry) = run_meta .execution_chain .execution_results .iter_mut() .find(|e| e.slice_id == slice_id) { match stage { PipelineStage::Prove => entry.proof_execution = Some(slice_result), PipelineStage::Verify => entry.verification_execution = Some(slice_result), } } else { tracing::warn!( slice = %slice_id, stage = ?stage, success = slice_result.success, error = slice_result.error.as_deref().unwrap_or("none"), "no matching execution_results entry, result dropped" ); } }
    // Record this stage's success count and persist the updated metadata.
    match stage { PipelineStage::Prove => run_meta.execution_chain.jstprove_proved_slices = succeeded, PipelineStage::Verify => run_meta.execution_chain.jstprove_verified_slices = succeeded, } let meta_bytes = rmp_serde::to_vec_named(&run_meta)?; std::fs::write(&meta_path, meta_bytes).map_err(|e| DsperseError::io(e, &meta_path))?; tracing::info!( succeeded, total = circuit_slices.len(), "{} complete", stage.action_label() ); Ok(run_meta) }
/// Run `stage` for one slice.
///
/// Tiled slices are delegated to `execute_tiled_stage`. Otherwise the slice's
/// witness is read from `slice_run_dir` and handed to
/// `execute_stage_operation`. A missing circuit path is an error; an
/// unreadable witness file yields a failed `SliceResult` instead.
fn execute_single_slice( stage: PipelineStage, slices_dir: &Path, slice_run_dir: &Path, slice_id: &str, meta: &RunSliceMetadata, backend: &dyn ProofBackend, ) -> Result { if let Some(ref tiling) = meta.tiling { let default_circuit_path = meta .jstprove_circuit_path .as_deref() .map(|p| resolve_relative_path(slices_dir, p)) .transpose()?; return execute_tiled_stage( stage, slice_id, default_circuit_path.as_deref(), slice_run_dir, tiling, slices_dir, backend, ); } let circuit_path = meta .jstprove_circuit_path .as_deref() .map(|p| resolve_relative_path(slices_dir, p)) .transpose()? .ok_or_else(|| DsperseError::Pipeline(format!("no circuit path for {slice_id}")))?; let start = std::time::Instant::now(); let method = stage.execution_method(); let witness_path = slice_run_dir.join(crate::utils::paths::WITNESS_FILE); let witness_bytes = match crate::utils::limits::read_checked(&witness_path) { Ok(b) => b, Err(e) => { return Ok(SliceResult::failure( slice_id, method, format!("witness file read error: {}: {e}", witness_path.display()), start.elapsed().as_secs_f64(), )); } }; execute_stage_operation( stage, slice_id, &circuit_path, &witness_bytes, slice_run_dir, backend, start, method, ) }
/// Prove or verify one slice, given its circuit and witness bytes.
///
/// Prove writes `PROOF_FILE` into `output_dir`, and verify reads it back.
/// Backend errors and proof-write errors propagate as `Err`. An unreadable
/// proof file or a rejected proof yields a failed `SliceResult`.
#[allow(clippy::too_many_arguments)] fn execute_stage_operation( stage: PipelineStage, slice_id: &str, circuit_path: &Path, witness_bytes: &[u8], output_dir: &Path, backend: &dyn ProofBackend, start: std::time::Instant, method: ExecutionMethod, ) -> Result { match stage { PipelineStage::Prove => { let proof_bytes = backend.prove(circuit_path, witness_bytes)?; let proof_path = output_dir.join(crate::utils::paths::PROOF_FILE); std::fs::write(&proof_path, &proof_bytes) .map_err(|e| DsperseError::io(e, &proof_path))?; let mut result = SliceResult::success(slice_id, method, start.elapsed().as_secs_f64()); result.proof_path = Some(proof_path.to_string_lossy().into_owned()); Ok(result) } PipelineStage::Verify => { let proof_path = output_dir.join(crate::utils::paths::PROOF_FILE); let proof_bytes = match crate::utils::limits::read_checked(&proof_path) { Ok(b) => b, Err(e) => { return Ok(SliceResult::failure( slice_id, method, format!("proof file read error: {}: {e}", proof_path.display()), start.elapsed().as_secs_f64(), )); } }; let valid = backend.verify(circuit_path, witness_bytes, &proof_bytes)?; let elapsed = start.elapsed().as_secs_f64(); let mut result = if valid { SliceResult::success(slice_id, method, elapsed) } else { SliceResult::failure( slice_id, method, "proof verification failed".into(), elapsed, ) }; result.proof_path =
Some(proof_path.to_string_lossy().into_owned()); Ok(result) } } }
/// Prove or verify every tile of a tiled slice in parallel.
///
/// Tile `i` reads and writes its witness and proof files under
/// `slice_run_dir/tile_{i}`. Its circuit is chosen by `resolve_tile_circuit`,
/// which falls back to `default_circuit_path`. Per-tile problems become failed
/// `TileResult`s, and the slice fails if any tile fails. The only `Err` is
/// for `tiling.num_tiles == 0`.
fn execute_tiled_stage( stage: PipelineStage, slice_id: &str, default_circuit_path: Option<&Path>, slice_run_dir: &Path, tiling: &TilingInfo, slices_dir: &Path, backend: &dyn ProofBackend, ) -> Result { if tiling.num_tiles == 0 { return Err(DsperseError::Pipeline(format!( "{slice_id}: tiling.num_tiles is 0" ))); } let start = std::time::Instant::now(); let method = stage.execution_method();
    // Each tile failure is captured in that tile's TileResult via `fail` rather
    // than aborting the slice, so the remaining tiles still run.
    let tile_results: Vec = (0..tiling.num_tiles) .into_par_iter() .map(|tile_idx| { let tile_start = std::time::Instant::now(); let fail = |error: String| { TileResult::failure( tile_idx, error, Some(method), tile_start.elapsed().as_secs_f64(), ) }; let tile_dir = slice_run_dir.join(format!("tile_{tile_idx}")); let tile_circuit_path = match resolve_tile_circuit(tiling, tile_idx, slices_dir, default_circuit_path) { Ok(Some(p)) => p, Ok(None) => return fail(format!("no circuit path for tile {tile_idx}")), Err(e) => return fail(e), }; let witness_path = tile_dir.join(crate::utils::paths::WITNESS_FILE); let witness_bytes = match crate::utils::limits::read_checked(&witness_path) { Ok(b) => b, Err(e) => { return fail(format!( "witness read error: {}: {e}", witness_path.display() )); } }; execute_tile_stage_operation( stage, tile_idx, &tile_circuit_path, &witness_bytes, &tile_dir, backend, method, tile_start, ) }) .collect();
    // Summarise: the slice succeeds only if every tile succeeded.
    let failed = tile_results.iter().filter(|t| !t.success).count(); let all_success = failed == 0; let elapsed = start.elapsed().as_secs_f64(); let mut result = if all_success { SliceResult::success(slice_id, method, elapsed) } else { SliceResult::failure( slice_id, method, format!("{failed} of {} tiles failed", tiling.num_tiles), elapsed, ) }; result.tiles = tile_results; Ok(result) }
/// Prove or verify a single tile. Every failure (backend error, proof write or
/// read error, rejected proof) is reported in the returned `TileResult`.
#[allow(clippy::too_many_arguments)] fn execute_tile_stage_operation( stage: PipelineStage, tile_idx: usize, circuit_path: &Path, witness_bytes: &[u8], tile_dir: &Path, backend: &dyn ProofBackend, method: ExecutionMethod, tile_start: std::time::Instant, ) -> TileResult { let fail = |error: String| { TileResult::failure( tile_idx, error, Some(method), tile_start.elapsed().as_secs_f64(), ) }; match stage { PipelineStage::Prove => { let proof_bytes = match backend.prove(circuit_path, witness_bytes) { Ok(b) => b, Err(e) => return fail(e.to_string()), }; let proof_path = tile_dir.join(crate::utils::paths::PROOF_FILE); if let Err(e) = std::fs::write(&proof_path, &proof_bytes) { return fail(format!("write proof: {}: {e}", proof_path.display())); } let mut result = TileResult::success(tile_idx, Some(method), tile_start.elapsed().as_secs_f64()); result.proof_path = Some(proof_path.to_string_lossy().into_owned()); result } PipelineStage::Verify => { let proof_path = tile_dir.join(crate::utils::paths::PROOF_FILE); let proof_bytes = match crate::utils::limits::read_checked(&proof_path) { Ok(b) => b, Err(e) => { return fail(format!("proof read error: {}: {e}", proof_path.display())); } }; let valid = match backend.verify(circuit_path, witness_bytes, &proof_bytes) { Ok(v) => v, Err(e) => return fail(e.to_string()), }; let elapsed = tile_start.elapsed().as_secs_f64(); let mut result = if valid { TileResult::success(tile_idx, Some(method), elapsed) } else { TileResult::failure( tile_idx, "proof verification failed".into(), Some(method), elapsed, ) }; result.proof_path = Some(proof_path.to_string_lossy().into_owned()); result } } }
================================================ FILE: crates/dsperse/src/pipeline/strategy.rs ================================================
use crate::error::{DsperseError, Result}; use crate::schema::execution::ExecutionMethod; use crate::schema::metadata::RunSliceMetadata; use crate::schema::tiling::{ChannelSplitInfo, DimSplitInfo, SplitStrategy, TilingInfo};
/// How a slice is executed, borrowing the split metadata it needs from the
/// slice's `RunSliceMetadata`.
pub enum ExecutionStrategy<'a> { ChannelSplit(&'a ChannelSplitInfo), DimSplit(&'a DimSplitInfo), Tiled(&'a TilingInfo), Single { use_circuit: bool }, }
impl<'a> ExecutionStrategy<'a> {
    /// Select the strategy for a slice from its split metadata.
    ///
    /// Errors if more than one of `channel_split`, `dim_split`, and `tiling`
    /// is set. A `dim_split` without a `template_path` falls back to `Single`.
    pub fn
from_metadata(meta: &'a RunSliceMetadata, use_circuit: bool) -> Result { let has_cs = meta.channel_split.is_some(); let has_ds = meta.dim_split.is_some(); let has_tiling = meta.tiling.is_some(); let count = has_cs as u8 + has_ds as u8 + has_tiling as u8; if count > 1 { return Err(DsperseError::Metadata(format!( "slice has multiple split metadata (channel_split={has_cs}, \ dim_split={has_ds}, tiling={has_tiling}; path={:?})", meta.path ))); } match meta.split_strategy() { Some(SplitStrategy::ChannelSplit(cs)) => Ok(Self::ChannelSplit(cs)), Some(SplitStrategy::DimSplit(ds)) => { if ds.template_path.is_none() { // Template creation may have been rejected (axis- // separability, unsupported split kind) or the template // was not included in the bundle. Fall back to the // non-template Single execution path (which may still // use circuit-based witness generation if use_circuit is // set) so already-published bundles with template-less // dim_split metadata remain runnable. tracing::debug!( path = ?meta.path, split_kind = ?ds.split_kind, "dim_split template_path missing, falling back to single execution" ); Ok(Self::Single { use_circuit }) } else { Ok(Self::DimSplit(ds)) } } Some(SplitStrategy::Tiled(t)) => Ok(Self::Tiled(t)), None => Ok(Self::Single { use_circuit }), } } pub fn execution_method(&self) -> ExecutionMethod { match self { Self::ChannelSplit(_) => ExecutionMethod::ChannelSplit, Self::DimSplit(_) => ExecutionMethod::DimSplit, Self::Tiled(_) => ExecutionMethod::Tiled, Self::Single { use_circuit: true } => ExecutionMethod::JstproveGenWitness, Self::Single { use_circuit: false } => ExecutionMethod::OnnxOnly, } } pub fn output_name(&self) -> Option<&str> { match self { Self::ChannelSplit(cs) => Some(&cs.output_name), Self::DimSplit(ds) => Some(&ds.output_name), Self::Tiled(tiling) => Some(&tiling.output_name), Self::Single { .. 
} => None, } } } ================================================ FILE: crates/dsperse/src/pipeline/tensor_store.rs ================================================ use std::collections::HashMap; use ndarray::ArrayD; use crate::error::{DsperseError, Result}; #[derive(Default)] pub struct TensorStore { tensors: HashMap>, } impl TensorStore { pub fn new() -> Self { Self::default() } pub fn get(&self, name: &str) -> Result<&ArrayD> { self.tensors .get(name) .ok_or_else(|| DsperseError::Pipeline(format!("tensor '{name}' not found in store"))) } pub fn try_get(&self, name: &str) -> Option<&ArrayD> { self.tensors.get(name) } pub fn put(&mut self, name: String, tensor: ArrayD) { self.tensors.insert(name, tensor); } pub fn contains(&self, name: &str) -> bool { self.tensors.contains_key(name) } pub fn len(&self) -> usize { self.tensors.len() } pub fn is_empty(&self) -> bool { self.tensors.is_empty() } pub fn keys(&self) -> impl Iterator { self.tensors.keys() } pub fn as_map(&self) -> &HashMap> { &self.tensors } pub fn gather(&self, names: &[String]) -> Result> { crate::utils::io::gather_inputs_from_cache(&self.tensors, names) } } #[cfg(test)] mod tests { use super::*; use ndarray::IxDyn; #[test] fn put_and_get() { let mut store = TensorStore::new(); let arr = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0]).unwrap(); store.put("x".into(), arr.clone()); assert_eq!(store.get("x").unwrap(), &arr); } #[test] fn get_missing_returns_error() { let store = TensorStore::new(); assert!(store.get("missing").is_err()); } #[test] fn try_get_missing_returns_none() { let store = TensorStore::new(); assert!(store.try_get("missing").is_none()); } #[test] fn contains_check() { let mut store = TensorStore::new(); assert!(!store.contains("a")); store.put( "a".into(), ArrayD::from_shape_vec(IxDyn(&[1]), vec![0.0]).unwrap(), ); assert!(store.contains("a")); } } ================================================ FILE: crates/dsperse/src/pipeline/tile_executor.rs 
================================================
use std::path::{Path, PathBuf}; use rayon::prelude::*; use crate::error::{DsperseError, Result}; use crate::schema::tiling::TilingInfo; use crate::utils::paths::resolve_relative_path;

/// Resolve the circuit path for tile `tile_idx`.
///
/// The lookup order is:
/// 1. the per-tile entry in `tiling.tiles`;
/// 2. the shared `tiling.tile` entry;
/// 3. `default_circuit`.
///
/// The first two are resolved against `slices_dir`. Resolution failures are
/// returned as their error strings.
pub fn resolve_tile_circuit(
    tiling: &TilingInfo,
    tile_idx: usize,
    slices_dir: &Path,
    default_circuit: Option<&Path>,
) -> std::result::Result, String> {
    let per_tile = tiling
        .tiles
        .as_deref()
        .and_then(|all| all.get(tile_idx))
        .and_then(|info| info.jstprove_circuit_path.as_deref());
    let shared = || {
        tiling
            .tile
            .as_ref()
            .and_then(|info| info.jstprove_circuit_path.as_deref())
    };
    let Some(relative) = per_tile.or_else(shared) else {
        return Ok(default_circuit.map(Path::to_path_buf));
    };
    resolve_relative_path(slices_dir, relative)
        .map(Some)
        .map_err(|e| e.to_string())
}

/// Run `op` for every tile index in `0..num_tiles` on a dedicated rayon pool
/// (built with `parallel` as its thread count) and collect the results.
///
/// Errors if `num_tiles` is 0 or the pool cannot be built.
pub fn execute_tiles(parallel: usize, num_tiles: usize, op: F) -> Result>
where
    T: Send,
    F: Fn(usize) -> T + Send + Sync,
{
    if num_tiles == 0 {
        return Err(DsperseError::Pipeline("num_tiles is 0".into()));
    }
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(parallel)
        .build()
        .map_err(|e| DsperseError::Pipeline(format!("thread pool: {e}")))?;
    Ok(pool.install(move || (0..num_tiles).into_par_iter().map(op).collect()))
}

// Tests for circuit-path resolution and the parallel tile runner.
#[cfg(test)] mod tests { use super::*; use crate::schema::tiling::{TileInfo, TilingInfo};
    // A 2x2 grid of size-4 tiles over an 8x8 input, with no tile circuit entries.
    fn make_tiling() -> TilingInfo { TilingInfo { slice_idx: 0, tile_size: 4, num_tiles: 4, tiles_y: 2, tiles_x: 2, halo: [0, 0, 0, 0], out_tile: [4, 4], stride: [1, 1], c_in: 1, c_out: 1, input_name: "input".into(), output_name: "output".into(), input_names: vec![], ndim: 4, h: 8, w: 8, tile: None, tiles: None, segment_size: None, total_elements: None, original_shape: vec![], } }
    // No tile entries and no default resolve to None.
    #[test] fn resolve_tile_circuit_no_info() { let tiling = make_tiling(); let result = resolve_tile_circuit(&tiling, 0, Path::new("/tmp"), None); assert_eq!(result.unwrap(), None); }
    // With no tile entries the default circuit is returned unchanged.
    #[test] fn resolve_tile_circuit_with_default() { let tiling = make_tiling(); let default = PathBuf::from("/tmp/circuit.bundle"); let result = resolve_tile_circuit(&tiling, 0, Path::new("/tmp"), Some(&default)); assert_eq!(result.unwrap(), Some(default)); }
    // The shared `tile` entry's circuit path is resolved.
    #[test] fn resolve_tile_circuit_from_single_tile() { let mut tiling = make_tiling(); tiling.tile = Some(TileInfo { path: "tile.onnx".into(), conv_out: [4, 4], jstprove_circuit_path: Some("jstprove/circuit.bundle".into()), }); let result = resolve_tile_circuit(&tiling, 0, Path::new("/slices"), None); let resolved = result.unwrap().unwrap(); assert!(resolved.to_string_lossy().contains("circuit.bundle")); }
    // Every tile index is processed exactly once.
    #[test] fn execute_tiles_collects_results() { let results = execute_tiles(2, 4, |i| i * 2).unwrap(); assert_eq!(results.len(), 4); let mut sorted = results.clone(); sorted.sort(); assert_eq!(sorted, vec![0, 2, 4, 6]); }
    // Zero tiles is an error.
    #[test] fn execute_tiles_zero_tiles_errors() { let result = execute_tiles(1, 0, |i| i); assert!(result.is_err()); } }
================================================ FILE: crates/dsperse/src/pipeline/tiled.rs ================================================
use std::collections::HashMap; use std::path::Path; use std::sync::Arc; use ndarray::{Array4, ArrayD, IxDyn, s}; use rayon::prelude::*; use super::tensor_store::TensorStore; use crate::backend::jstprove::JstproveBackend; use crate::error::{DsperseError, Result}; use crate::schema::execution::{ExecutionInfo, ExecutionMethod, TileResult}; use crate::schema::tiling::TilingInfo; use crate::slicer::onnx_proto::TensorProto; use crate::utils::paths::resolve_relative_path; use super::runner::{ RunConfig, extract_initializers_from_map, extract_onnx_initializers, resolve_circuit_path_optional, run_onnx_inference, };
#[allow(clippy::too_many_arguments)] pub(crate) fn execute_tiled( slices_dir: &Path, slice_run_dir: &Path, slice_id: &str, tiling: &TilingInfo, slice_circuit_path: Option<&Path>, tensor_cache: &TensorStore, backend: &JstproveBackend, config: &RunConfig, donor_init_map: Option<&HashMap>, )
-> Result { let all_names = tiling.all_input_names(); let multi_input = all_names.len() > 1; let is_fixed_segment = tiling.ndim == 1; let is_1d = tiling.ndim == 3; let all_tiles_dyn = if is_fixed_segment { prepare_fixed_segments_from_cache(tiling, tensor_cache)? } else { prepare_tiles_from_cache(tiling, tensor_cache, is_1d)? }; let num_tiles = all_tiles_dyn[0].len(); tracing::info!( slice = %slice_id, num_tiles, tile_size = tiling.tile_size, ndim = tiling.ndim, "splitting into tiles" ); let tile_infos = tiling.tiles.as_deref().unwrap_or(&[]); let single_tile = tiling.tile.as_ref(); if tile_infos.is_empty() && single_tile.is_none() { return Err(DsperseError::Pipeline(format!( "tiling for '{}' has neither tile list nor single tile template", tiling.output_name ))); } let first_tile_info = tile_infos.first().or(single_tile); let first_tile_onnx = first_tile_info .map(|ti| resolve_relative_path(slices_dir, &ti.path)) .transpose()?; let warm_model = if multi_input || is_1d || is_fixed_segment { None } else { match (first_tile_onnx.as_deref(), all_tiles_dyn[0].first()) { (Some(onnx_path), Some(sample)) => { let shape = sample.shape().to_vec(); let model = crate::backend::onnx::WarmModel::load(onnx_path, &shape)?; tracing::info!(slice = %slice_id, "loaded ONNX model"); Some(model) } _ => None, } }; let circuit_path = resolve_circuit_path_optional( slices_dir, first_tile_info.and_then(|ti| ti.jstprove_circuit_path.as_deref()), )? 
.or_else(|| slice_circuit_path.map(|p| p.to_path_buf())); let warm_circuit = match (&circuit_path, &first_tile_onnx) { (Some(cp), Some(onnx_path)) => { let params = backend.load_params(cp)?; let is_wai = params.as_ref().is_some_and(|p| p.weights_as_inputs); if donor_init_map.is_some() && !is_wai { return Err(DsperseError::Pipeline(format!( "{slice_id}: consumer weights require circuits compiled with --weights-as-inputs" ))); } let initializers = if is_wai { if let Some(map) = donor_init_map { extract_initializers_from_map(map, params.as_ref().unwrap())? } else { extract_onnx_initializers(onnx_path, params.as_ref().unwrap())? } } else { vec![] }; let wc = crate::backend::jstprove::WarmCircuit::load(cp, initializers, backend)?; tracing::info!(slice = %slice_id, wai = is_wai, "loaded circuit bundle"); Some(wc) } _ => None, }; let warm_model = warm_model.map(Arc::new); let warm_circuit = warm_circuit.map(Arc::new); let circuit_path = circuit_path.map(Arc::from); let pool = rayon::ThreadPoolBuilder::new() .num_threads(config.parallel) .build() .map_err(|e| DsperseError::Pipeline(format!("thread pool: {e}")))?; let tile_input_names: Vec = if all_names.len() > 1 { (0..all_names.len()) .map(|i| format!("tile_in_{i}")) .collect() } else { vec!["tile_in".to_string()] }; let collected: Vec<(TileResult, Option>)> = pool.install(|| { (0..num_tiles) .into_par_iter() .map(|tile_idx| { let start = std::time::Instant::now(); let tile_dir = slice_run_dir.join(format!("tile_{tile_idx}")); if let Err(e) = std::fs::create_dir_all(&tile_dir) { return ( TileResult::failure( tile_idx, format!("mkdir: {e}"), None, start.elapsed().as_secs_f64(), ), None, ); } let tile_info = tile_infos.get(tile_idx).or(single_tile); let tile_dyn = all_tiles_dyn[0][tile_idx].clone(); let per_tile_onnx = tile_info .map(|ti| resolve_relative_path(slices_dir, &ti.path)) .transpose(); let per_tile_onnx = match per_tile_onnx { Ok(p) => p, Err(e) => { return ( TileResult::failure( tile_idx, format!("resolve tile 
path: {e}"), None, start.elapsed().as_secs_f64(), ), None, ); } }; let effective_tile_onnx_ref = per_tile_onnx.as_deref(); if tile_info.is_none() { return ( TileResult::failure( tile_idx, "no tile circuit info".into(), None, start.elapsed().as_secs_f64(), ), None, ); } let tile_output = if multi_input || is_1d || is_fixed_segment { if let Some(onnx) = effective_tile_onnx_ref { let inputs: Vec<(&str, Vec, Vec)> = all_tiles_dyn .iter() .zip(tile_input_names.iter()) .map(|(input_tiles, tile_name)| { let t = &input_tiles[tile_idx]; let shape: Vec = t.shape().to_vec(); let data: Vec = t.iter().copied().collect(); (tile_name.as_str(), data, shape) }) .collect(); crate::backend::onnx::run_inference_multi_named(onnx, &inputs).and_then( |named| { let (data, shape) = named.into_values().next().ok_or_else(|| { DsperseError::Pipeline( "multi-input tile produced no output".into(), ) })?; ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|e| { DsperseError::Pipeline(format!( "multi-input tile output reshape: {e}" )) }) }, ) } else { Err(DsperseError::Pipeline(format!( "tile {tile_idx}: no ONNX model available for inference" ))) } } else if let Some(ref wm) = warm_model { let input_flat: Vec = tile_dyn.iter().copied().collect(); wm.run(&input_flat).and_then(|(data, shape)| { ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|e| { crate::error::DsperseError::Pipeline(format!( "warm model output reshape: {e}" )) }) }) } else if let Some(onnx) = effective_tile_onnx_ref { run_onnx_inference(onnx, &tile_dyn) } else { Err(DsperseError::Pipeline(format!( "tile {tile_idx}: no ONNX model available for inference" ))) }; let output_tensor = match tile_output { Ok(t) => t, Err(e) => { return ( TileResult::failure( tile_idx, format!("onnx inference: {e}"), Some(ExecutionMethod::OnnxOnly), start.elapsed().as_secs_f64(), ), None, ); } }; if circuit_path.is_none() { return ( TileResult::success( tile_idx, Some(ExecutionMethod::OnnxOnly), start.elapsed().as_secs_f64(), ), Some(output_tensor), 
); } let flat: Vec = flatten_tile_inputs(&all_tiles_dyn, tile_idx); let witness_result = if let Some(ref wc) = warm_circuit { wc.witness_f64(&flat) } else { let cp = circuit_path .as_ref() .expect("circuit_path is Some: guarded by early return"); backend.witness_f64(cp, &flat, &[]) }; match witness_result { Ok(witness_bytes) => { let witness_path = tile_dir.join(crate::utils::paths::WITNESS_FILE); if let Err(e) = std::fs::write(&witness_path, &witness_bytes) { return ( TileResult::failure( tile_idx, format!("write witness: {e}"), Some(ExecutionMethod::JstproveGenWitness), start.elapsed().as_secs_f64(), ), None, ); } ( TileResult::success( tile_idx, Some(ExecutionMethod::JstproveGenWitness), start.elapsed().as_secs_f64(), ), Some(output_tensor), ) } Err(e) => ( TileResult::failure( tile_idx, e.to_string(), Some(ExecutionMethod::JstproveGenWitness), start.elapsed().as_secs_f64(), ), None, ), } }) .collect() }); let mut tile_results: Vec = Vec::with_capacity(collected.len()); let mut tile_outputs: Vec> = Vec::with_capacity(collected.len()); for (result, output) in collected { if let Some(o) = output { tile_outputs.push(o); } tile_results.push(result); } if tile_results.is_empty() { return Err(DsperseError::Pipeline(format!( "tiling produced zero tiles for '{}'", tiling.output_name ))); } let all_success = tile_results.iter().all(|r| r.success); if !all_success { let failed: Vec<_> = tile_results .iter() .filter(|r| !r.success) .map(|r| format!("tile {}: {}", r.tile_idx, r.error.as_deref().unwrap_or("?"))) .collect(); return Err(DsperseError::Pipeline(format!( "tiled execution failed for '{}': {}", tiling.output_name, failed.join("; ") ))); } debug_assert!( !tile_outputs.is_empty(), "all tiles reported success but no outputs for '{}'", tiling.output_name ); let reconstructed = if is_fixed_segment { reconstruct_from_fixed_segments(&tile_outputs, tiling)? } else if is_1d { let r = reconstruct_from_tiles_1d(&tile_outputs, tiling)?; trim_to_original_seq(r, tiling)? 
} else { let r = reconstruct_from_tiles(&tile_outputs, tiling)?; trim_to_original_dims(r, tiling)? }; Ok(crate::schema::execution::StrategyOutput { info: ExecutionInfo { method: ExecutionMethod::Tiled, success: true, error: None, witness_file: None, tile_exec_infos: tile_results, }, outputs: vec![(tiling.output_name.clone(), reconstructed)], }) } /// Witness-only tiled execution for combined inference mode. /// /// The full-model ONNX inference has already run and populated the tensor /// cache with all intermediate activations. This function splits those /// cached activations into tiles, generates per-tile ZK witnesses via the /// circuit backend, and returns tile-level execution results. It does NOT /// reconstruct output tensors — those already exist in the cache from the /// monolithic inference pass — hence the empty `outputs` vec in the /// returned `StrategyOutput`. #[allow(clippy::too_many_arguments)] pub(crate) fn execute_combined_tiled( slices_dir: &Path, slice_run_dir: &Path, slice_id: &str, tiling: &TilingInfo, slice_circuit_path: Option<&str>, tensor_cache: &TensorStore, backend: &JstproveBackend, config: &RunConfig, donor_init_map: Option<&HashMap>, ) -> Result { let is_fixed_segment = tiling.ndim == 1; let is_1d = tiling.ndim == 3; let all_tiles_dyn = if is_fixed_segment { prepare_fixed_segments_from_cache(tiling, tensor_cache)? } else { prepare_tiles_from_cache(tiling, tensor_cache, is_1d)? 
}; let num_tiles = all_tiles_dyn[0].len(); tracing::info!( slice = %slice_id, num_tiles, tile_size = tiling.tile_size, "splitting combined activations into tiles for witness generation" ); let tile_infos = tiling.tiles.as_deref().unwrap_or(&[]); let single_tile = tiling.tile.as_ref(); let first_tile_info = tile_infos.first().or(single_tile); let circuit_path = resolve_circuit_path_optional( slices_dir, first_tile_info .and_then(|ti| ti.jstprove_circuit_path.as_deref()) .or(slice_circuit_path), )?; let circuit_path = match circuit_path { Some(p) => p, None => { return Ok(crate::schema::execution::StrategyOutput { info: ExecutionInfo { method: ExecutionMethod::Tiled, success: true, error: None, witness_file: None, tile_exec_infos: (0..num_tiles) .map(|i| TileResult::success(i, Some(ExecutionMethod::OnnxOnly), 0.0)) .collect(), }, outputs: vec![], }); } }; let first_tile_onnx = first_tile_info .map(|ti| resolve_relative_path(slices_dir, &ti.path)) .transpose()?; let patched_tile_onnx = match (&first_tile_onnx, donor_init_map) { (Some(onnx_path), Some(map)) => Some(crate::slicer::onnx_proto::build_patched_onnx( onnx_path, map, )?), _ => None, }; let effective_tile_onnx = patched_tile_onnx.as_ref().map(|t| t.path().to_path_buf()); let effective_tile_onnx_ref = effective_tile_onnx .as_deref() .or(first_tile_onnx.as_deref()); let params = backend.load_params(&circuit_path)?; let is_wai = params.as_ref().is_some_and(|p| p.weights_as_inputs); if donor_init_map.is_some() && !is_wai { return Err(DsperseError::Pipeline(format!( "{slice_id}: consumer weights require circuits compiled with --weights-as-inputs" ))); } let warm_circuit = match effective_tile_onnx_ref { Some(onnx_path) => { let initializers = if is_wai { if let Some(map) = donor_init_map { extract_initializers_from_map(map, params.as_ref().unwrap())? } else { extract_onnx_initializers(onnx_path, params.as_ref().unwrap())? 
} } else { vec![] }; let wc = crate::backend::jstprove::WarmCircuit::load(&circuit_path, initializers, backend)?; tracing::info!(slice = %slice_id, wai = is_wai, "loaded tile circuit for combined tiling"); Some(wc) } None => None, }; let warm_circuit = warm_circuit.map(Arc::new); let circuit_path = Arc::from(circuit_path); let pool = rayon::ThreadPoolBuilder::new() .num_threads(config.parallel) .build() .map_err(|e| DsperseError::Pipeline(format!("thread pool: {e}")))?; let collected: Vec = pool.install(|| { (0..num_tiles) .into_par_iter() .map(|tile_idx| { let start = std::time::Instant::now(); let tile_dir = slice_run_dir.join(format!("tile_{tile_idx}")); if let Err(e) = std::fs::create_dir_all(&tile_dir) { return TileResult::failure( tile_idx, format!("mkdir: {e}"), None, start.elapsed().as_secs_f64(), ); } let flat: Vec = flatten_tile_inputs(&all_tiles_dyn, tile_idx); let witness_result = if let Some(ref wc) = warm_circuit { wc.witness_f64(&flat) } else { backend.witness_f64(&circuit_path, &flat, &[]) }; match witness_result { Ok(witness_bytes) => { let witness_path = tile_dir.join(crate::utils::paths::WITNESS_FILE); if let Err(e) = std::fs::write(&witness_path, &witness_bytes) { return TileResult::failure( tile_idx, format!("write witness: {e}"), Some(ExecutionMethod::JstproveGenWitness), start.elapsed().as_secs_f64(), ); } TileResult::success( tile_idx, Some(ExecutionMethod::JstproveGenWitness), start.elapsed().as_secs_f64(), ) } Err(e) => TileResult::failure( tile_idx, e.to_string(), Some(ExecutionMethod::JstproveGenWitness), start.elapsed().as_secs_f64(), ), } }) .collect() }); let all_success = collected.iter().all(|r| r.success); if !all_success { let failed: Vec<_> = collected .iter() .filter(|r| !r.success) .map(|r| format!("tile {}: {}", r.tile_idx, r.error.as_deref().unwrap_or("?"))) .collect(); return Err(DsperseError::Pipeline(format!( "{slice_id}: tiled witness generation failed: {}", failed.join("; ") ))); } tracing::info!( slice = %slice_id, 
num_tiles, "tiled witness generation from combined outputs complete" ); // No output tensors: combined mode already has activations in cache // from the monolithic ONNX run. Only witness artifacts are produced here. Ok(crate::schema::execution::StrategyOutput { info: ExecutionInfo { method: ExecutionMethod::Tiled, success: true, error: None, witness_file: None, tile_exec_infos: collected, }, outputs: vec![], }) } pub(crate) fn prepare_tiles_from_cache( tiling: &TilingInfo, tensor_cache: &TensorStore, is_1d: bool, ) -> Result>>> { let all_names = tiling.all_input_names(); let mut all_tiles: Vec>> = Vec::with_capacity(all_names.len()); for name in &all_names { let input_arr = tensor_cache.get(name)?.clone(); if is_1d { let tiles = split_into_tiles_1d(&input_arr, tiling)?; all_tiles.push(tiles); } else { let input_4d = if input_arr.ndim() == 4 { let s = input_arr.shape(); Array4::from_shape_vec( (s[0], s[1], s[2], s[3]), input_arr.iter().copied().collect(), ) .map_err(|e| DsperseError::Pipeline(format!("tiling input reshape: {e}")))? } else { let input_flat: Vec = input_arr.iter().copied().collect(); let h = if tiling.h > 0 { tiling.h } else { tiling.tiles_y * tiling.tile_size }; let w = if tiling.w > 0 { tiling.w } else { tiling.tiles_x * tiling.tile_size }; reshape_to_4d(&input_flat, tiling.c_in, h, w)? 
}; let tiles = split_into_tiles(&input_4d, tiling)?; all_tiles.push(tiles.into_iter().map(|t| t.into_dyn()).collect()); } } Ok(all_tiles) } pub fn split_for_tiling(input: &ArrayD, tiling: &TilingInfo) -> Result>> { let is_fixed_segment = tiling.ndim == 1; if is_fixed_segment { let segment_size = tiling.segment_size.ok_or_else(|| { DsperseError::Pipeline("split_for_tiling: fixed segment missing segment_size".into()) })?; if segment_size == 0 { return Err(DsperseError::Pipeline( "split_for_tiling: segment_size must be > 0".into(), )); } let total_elements = tiling.total_elements.ok_or_else(|| { DsperseError::Pipeline("split_for_tiling: fixed segment missing total_elements".into()) })?; let flat: Vec = input.iter().copied().collect(); if flat.len() < total_elements { return Err(DsperseError::Pipeline(format!( "split_for_tiling: input has {} elements, expected at least {}", flat.len(), total_elements ))); } let num_segments = total_elements.div_ceil(segment_size); let mut segments = Vec::with_capacity(num_segments); for i in 0..num_segments { let start = i * segment_size; if start >= flat.len() { break; } let end = (start + segment_size).min(total_elements); let mut seg_data = vec![0.0f64; segment_size]; seg_data[..end - start].copy_from_slice(&flat[start..end]); segments.push( ArrayD::from_shape_vec(IxDyn(&[segment_size]), seg_data) .map_err(|e| DsperseError::Pipeline(format!("segment reshape: {e}")))?, ); } return Ok(segments); } let is_1d = tiling.ndim == 3; if is_1d { return split_into_tiles_1d(input, tiling); } let input_4d = if input.ndim() == 4 { let s = input.shape(); Array4::from_shape_vec((s[0], s[1], s[2], s[3]), input.iter().copied().collect()) .map_err(|e| DsperseError::Pipeline(format!("tiling input reshape: {e}")))? 
} else { let flat: Vec = input.iter().copied().collect(); let h = if tiling.h > 0 { tiling.h } else { tiling.tiles_y * tiling.tile_size }; let w = if tiling.w > 0 { tiling.w } else { tiling.tiles_x * tiling.tile_size }; reshape_to_4d(&flat, tiling.c_in, h, w)? }; let tiles = split_into_tiles(&input_4d, tiling)?; Ok(tiles.into_iter().map(|t| t.into_dyn()).collect()) } pub fn split_into_tiles(input: &Array4, tiling: &TilingInfo) -> Result>> { if tiling.halo.iter().any(|&v| v < 0) { return Err(DsperseError::Pipeline(format!( "negative halo values not supported: halo={:?}", tiling.halo ))); } let (n, c, h, w) = input.dim(); if n != 1 { return Err(DsperseError::Pipeline(format!( "split_into_tiles: batch size {n} not supported, expected 1" ))); } let halo_top = tiling.halo[0] as usize; let halo_left = tiling.halo[1] as usize; let halo_bottom = tiling.halo[2] as usize; let halo_right = tiling.halo[3] as usize; let tile_h = tiling.tile_size + halo_top + halo_bottom; let tile_w = tiling.tile_size + halo_left + halo_right; let padded_h = tiling.tiles_y * tiling.tile_size + halo_top + halo_bottom; let padded_w = tiling.tiles_x * tiling.tile_size + halo_left + halo_right; if halo_top + h > padded_h || halo_left + w > padded_w { return Err(DsperseError::Pipeline(format!( "split_into_tiles: input spatial ({h}x{w}) exceeds padded grid ({padded_h}x{padded_w})" ))); } let mut padded = Array4::::zeros((n, c, padded_h, padded_w)); padded .slice_mut(s![.., .., halo_top..halo_top + h, halo_left..halo_left + w]) .assign(input); let mut tiles = Vec::new(); for ty in 0..tiling.tiles_y { for tx in 0..tiling.tiles_x { let y_start = ty * tiling.tile_size; let x_start = tx * tiling.tile_size; let tile = padded .slice(s![ .., .., y_start..y_start + tile_h, x_start..x_start + tile_w ]) .to_owned(); tiles.push(tile); } } Ok(tiles) } pub fn reconstruct_from_tiles( tile_outputs: &[ArrayD], tiling: &TilingInfo, ) -> Result> { let expected_tiles = tiling.tiles_y * tiling.tiles_x; if 
tile_outputs.len() != expected_tiles { return Err(DsperseError::Pipeline(format!( "reconstruct: expected {} tiles ({}x{}), got {}", expected_tiles, tiling.tiles_y, tiling.tiles_x, tile_outputs.len() ))); } let out_h = tiling.out_tile[0].max(1) as usize; let out_w = tiling.out_tile[1].max(1) as usize; let c_out = tiling.c_out; let total_h = out_h * tiling.tiles_y; let total_w = out_w * tiling.tiles_x; let mut output = Array4::::zeros((1, c_out, total_h, total_w)); for (idx, tile_arr) in tile_outputs.iter().enumerate() { let ty = idx / tiling.tiles_x; let tx = idx % tiling.tiles_x; let tile_flat: Vec = tile_arr.iter().copied().collect(); if tile_flat.is_empty() { return Err(DsperseError::Pipeline(format!( "tile ({},{}) marked successful but produced no data", ty, tx ))); } let tile_elements = c_out * out_h * out_w; if tile_flat.len() != tile_elements { return Err(DsperseError::Pipeline(format!( "tile ({},{}) has {} elements, expected {} (c_out={}, out_h={}, out_w={})", ty, tx, tile_flat.len(), tile_elements, c_out, out_h, out_w ))); } let tile_4d = Array4::from_shape_vec((1, c_out, out_h, out_w), tile_flat.to_vec()) .map_err(|e| { DsperseError::Pipeline(format!("tile ({},{}) reshape failed: {e}", ty, tx)) })?; let y_start = ty * out_h; let x_start = tx * out_w; output .slice_mut(s![ .., .., y_start..y_start + out_h, x_start..x_start + out_w ]) .assign(&tile_4d); } Ok(output.into_dyn()) } pub(crate) fn trim_to_original_dims(arr: ArrayD, tiling: &TilingInfo) -> Result> { if tiling.h == 0 || tiling.w == 0 { return Ok(arr); } let stride_h = tiling.stride[0].max(1) as usize; let stride_w = tiling.stride[1].max(1) as usize; let expected_h = tiling.h / stride_h; let expected_w = tiling.w / stride_w; let grid_h = tiling.out_tile[0].max(1) as usize * tiling.tiles_y; let grid_w = tiling.out_tile[1].max(1) as usize * tiling.tiles_x; if grid_h > expected_h || grid_w > expected_w { if arr.ndim() != 4 { return Err(DsperseError::Pipeline(format!( "trim_to_original_dims: expected 4D 
array, got {}D", arr.ndim() ))); } Ok(arr .slice(s![.., .., ..expected_h, ..expected_w]) .to_owned() .into_dyn()) } else { Ok(arr) } } pub(crate) fn split_into_tiles_1d( input: &ArrayD, tiling: &TilingInfo, ) -> Result>> { let shape = input.shape(); if shape.len() != 3 { return Err(DsperseError::Pipeline(format!( "split_into_tiles_1d: expected 3D input, got {}D", shape.len() ))); } let (n, seq, _hidden) = (shape[0], shape[1], shape[2]); if n != 1 { return Err(DsperseError::Pipeline(format!( "split_into_tiles_1d: batch size {n} not supported, expected 1" ))); } let tile_size = tiling.tile_size; if tile_size == 0 || tiling.tiles_y == 0 { return Err(DsperseError::Pipeline(format!( "split_into_tiles_1d: invalid tiling config tile_size={}, tiles_y={}", tile_size, tiling.tiles_y ))); } let padded_seq = tiling .tiles_y .checked_mul(tile_size) .ok_or_else(|| DsperseError::Pipeline("split_into_tiles_1d: padded_seq overflow".into()))?; if seq > padded_seq { return Err(DsperseError::Pipeline(format!( "split_into_tiles_1d: input seq {seq} exceeds padded seq {padded_seq}" ))); } let mut padded = ArrayD::::zeros(vec![n, padded_seq, shape[2]]); padded.slice_mut(s![.., ..seq, ..]).assign(input); let mut tiles = Vec::with_capacity(tiling.tiles_y); for ty in 0..tiling.tiles_y { let start = ty * tile_size; let tile = padded .slice(s![.., start..start + tile_size, ..]) .to_owned() .into_dyn(); tiles.push(tile); } Ok(tiles) } pub(crate) fn reconstruct_from_tiles_1d( tile_outputs: &[ArrayD], tiling: &TilingInfo, ) -> Result> { if tile_outputs.is_empty() { return Err(DsperseError::Pipeline( "reconstruct_1d: no tile outputs".into(), )); } if tile_outputs.len() != tiling.tiles_y { return Err(DsperseError::Pipeline(format!( "reconstruct_1d: expected {} tiles, got {}", tiling.tiles_y, tile_outputs.len() ))); } let first = &tile_outputs[0]; if first.ndim() != 3 { return Err(DsperseError::Pipeline(format!( "reconstruct_1d: expected 3D tiles, got {}D", first.ndim() ))); } let fshape = 
first.shape(); let (tile_len, hidden) = (fshape[1], fshape[2]); let total_seq = tile_len * tile_outputs.len(); let mut output = ArrayD::::zeros(vec![1, total_seq, hidden]); for (idx, tile) in tile_outputs.iter().enumerate() { if tile.shape() != fshape { return Err(DsperseError::Pipeline(format!( "reconstruct_1d: tile {idx} shape {:?} != first tile shape {:?}", tile.shape(), fshape ))); } let start = idx * tile_len; output .slice_mut(s![.., start..start + tile_len, ..]) .assign(tile); } Ok(output) } pub(crate) fn trim_to_original_seq(arr: ArrayD, tiling: &TilingInfo) -> Result> { if tiling.h == 0 { return Ok(arr); } if arr.ndim() != 3 { return Err(DsperseError::Pipeline(format!( "trim_to_original_seq: expected 3D array, got {}D", arr.ndim() ))); } let current_seq = arr.shape()[1]; if current_seq > tiling.h { Ok(arr.slice(s![.., ..tiling.h, ..]).to_owned().into_dyn()) } else { Ok(arr) } } pub(crate) fn prepare_fixed_segments_from_cache( tiling: &TilingInfo, tensor_cache: &TensorStore, ) -> Result>>> { let segment_size = tiling.segment_size.ok_or_else(|| { DsperseError::Pipeline("fixed segment tiling missing segment_size".into()) })?; if segment_size == 0 { return Err(DsperseError::Pipeline( "fixed segment tiling has segment_size=0".into(), )); } let total_elements = tiling.total_elements.ok_or_else(|| { DsperseError::Pipeline("fixed segment tiling missing total_elements".into()) })?; let all_names = tiling.all_input_names(); let num_segments = total_elements.div_ceil(segment_size); let mut all_segments: Vec>> = Vec::with_capacity(all_names.len()); for name in &all_names { let input_arr = tensor_cache.get(name)?.clone(); let flat: Vec = input_arr.iter().copied().collect(); if flat.len() < total_elements { return Err(DsperseError::Pipeline(format!( "fixed segment: input '{}' has {} elements, expected at least {}", name, flat.len(), total_elements ))); } let mut segments = Vec::with_capacity(num_segments); for i in 0..num_segments { let start = i * segment_size; let end 
= (start + segment_size).min(total_elements); let mut seg_data = vec![0.0f64; segment_size]; seg_data[..end - start].copy_from_slice(&flat[start..end]); let seg = ArrayD::from_shape_vec(IxDyn(&[segment_size]), seg_data) .map_err(|e| DsperseError::Pipeline(format!("fixed segment reshape: {e}")))?; segments.push(seg); } all_segments.push(segments); } Ok(all_segments) } pub(crate) fn reconstruct_from_fixed_segments( segment_outputs: &[ArrayD], tiling: &TilingInfo, ) -> Result> { let total_elements = tiling.total_elements.ok_or_else(|| { DsperseError::Pipeline("reconstruct fixed segments: missing total_elements".into()) })?; if segment_outputs.is_empty() { return Err(DsperseError::Pipeline( "reconstruct fixed segments: no outputs".into(), )); } let mut flat = Vec::with_capacity(total_elements); for seg in segment_outputs { flat.extend(seg.iter().copied()); } flat.truncate(total_elements); let shape: Vec = if tiling.original_shape.is_empty() { vec![total_elements] } else { tiling.original_shape.iter().map(|&d| d as usize).collect() }; ArrayD::from_shape_vec(IxDyn(&shape), flat) .map_err(|e| DsperseError::Pipeline(format!("reconstruct fixed segments reshape: {e}"))) } pub(crate) fn reshape_to_4d(flat: &[f64], c: usize, h: usize, w: usize) -> Result> { let n = 1usize; let total = flat.len(); if n * c * h * w != total { return Err(DsperseError::Pipeline(format!( "cannot reshape {total} elements to 4D (n={n}, c={c}, h={h}, w={w})" ))); } Array4::from_shape_vec((n, c, h, w), flat.to_vec()) .map_err(|e| DsperseError::Pipeline(format!("reshape: {e}"))) } pub(crate) fn flatten_tile_inputs(all_tiles: &[Vec>], tile_idx: usize) -> Vec { let total: usize = all_tiles.iter().map(|tiles| tiles[tile_idx].len()).sum(); let mut flat = Vec::with_capacity(total); for input_tiles in all_tiles { flat.extend(input_tiles[tile_idx].iter().copied()); } flat } ================================================ FILE: crates/dsperse/src/pipeline/verifier.rs 
================================================ use std::path::Path; use crate::backend::ProofBackend; use crate::error::Result; use crate::schema::execution::RunMetadata; use super::stage::{PipelineStage, run_pipeline_stage}; pub fn verify_run( run_dir: &Path, slices_dir: &Path, backend: &dyn ProofBackend, parallel: usize, ) -> Result { run_pipeline_stage( PipelineStage::Verify, run_dir, slices_dir, backend, parallel, ) } ================================================ FILE: crates/dsperse/src/python.rs ================================================ use std::path::PathBuf; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use crate::backend::jstprove::JstproveBackend; use crate::error::DsperseError; use crate::pipeline::{self, RunConfig}; use jstprove_circuits::api::{ProofSystemParseError, ProofSystemType as ProofSystem}; fn to_py_err(e: DsperseError) -> PyErr { let msg = e.to_string(); match e { DsperseError::Io { .. } => pyo3::exceptions::PyIOError::new_err(msg), DsperseError::MsgpackEncode(_) | DsperseError::MsgpackDecode(_) => { pyo3::exceptions::PyValueError::new_err(msg) } DsperseError::Archive(_) | DsperseError::Metadata(_) => { pyo3::exceptions::PyValueError::new_err(msg) } DsperseError::Onnx(_) | DsperseError::Backend(_) | DsperseError::Slicer(_) | DsperseError::Pipeline(_) | DsperseError::Other(_) => PyRuntimeError::new_err(msg), } } fn to_pretty_json(value: &T) -> PyResult { serde_json::to_string_pretty(value).map_err(|e| { to_py_err(DsperseError::Other(format!( "pretty-json serialization failed: {e}" ))) }) } fn resolve_ops(proof_system: &str, circuit_ops: Option<&[String]>) -> PyResult> { let ps: ProofSystem = proof_system .parse() .map_err(|e: ProofSystemParseError| PyRuntimeError::new_err(e.to_string()))?; let supported = ps.supported_ops(); match circuit_ops { None => Ok(supported.iter().map(|s| (*s).to_string()).collect()), Some(ops) => { for op in ops { if !supported.contains(&op.as_str()) { return 
Err(PyRuntimeError::new_err(format!( "op {op:?} not supported by proof system {ps}. Supported: {supported:?}" ))); } } Ok(ops.to_vec()) } } } /* Reject parallel == 0 up front with a Python ValueError; every binding below that takes a `parallel` argument calls this before doing any work. */ fn require_nonzero(parallel: usize) -> PyResult<()> { if parallel == 0 { return Err(pyo3::exceptions::PyValueError::new_err( "parallel must be > 0", )); } Ok(()) } /* Python binding: slice the model at model_path and return the slicing metadata as pretty-printed JSON. The op list is resolved and checked against proof_system by resolve_ops before slicing starts; the slicing itself runs inside py.allow_threads, i.e. with the GIL released. */ #[pyfunction] #[pyo3(signature = (model_path, output_dir=None, tile_size=None, proof_system="expander", circuit_ops=None, input_shape=None))] fn slice_model( py: Python<'_>, model_path: &str, output_dir: Option<&str>, tile_size: Option, proof_system: &str, circuit_ops: Option>, input_shape: Option>, ) -> PyResult { let model = PathBuf::from(model_path); let out = output_dir.map(PathBuf::from); let ops = resolve_ops(proof_system, circuit_ops.as_deref())?; let ops_refs: Vec<&str> = ops.iter().map(String::as_str).collect(); let metadata = py .allow_threads(|| { crate::slicer::slice_model( &model, out.as_deref(), tile_size, &ops_refs, input_shape.as_deref(), ) }) .map_err(to_py_err)?; to_pretty_json(&metadata) } /* Python binding: compile the slices in slices_dir. An unparseable proof_config raises ValueError; if the returned report lists any failed slices, RuntimeError is raised once compilation has finished. */ #[pyfunction] #[allow(clippy::too_many_arguments)] #[pyo3(signature = (slices_dir, proof_config="bn254_raw", parallel=1, weights_as_inputs=true, layers=None, proof_system="expander", circuit_ops=None, skip_compile_over_size=None, holographic=false))] fn compile_slices( py: Python<'_>, slices_dir: &str, proof_config: &str, parallel: usize, weights_as_inputs: bool, layers: Option>, proof_system: &str, circuit_ops: Option>, skip_compile_over_size: Option, holographic: bool, ) -> PyResult<()> { require_nonzero(parallel)?; let backend = JstproveBackend::default(); let parsed_config: jstprove_circuits::api::ProofConfigType = proof_config .parse() .map_err(|e: jstprove_circuits::api::ProofConfigError| { pyo3::exceptions::PyValueError::new_err(e.to_string()) })?; let dir = PathBuf::from(slices_dir); let ops = resolve_ops(proof_system, circuit_ops.as_deref())?; let ops_refs: Vec<&str> = ops.iter().map(String::as_str).collect(); let report = py .allow_threads(|| {
pipeline::compile_slices( &dir, &backend, parsed_config, parallel, weights_as_inputs, layers.as_deref(), &ops_refs, skip_compile_over_size, holographic, ) }) .map_err(to_py_err)?; // Propagate partial-compile failures to the Python caller so // silent non-zero masks become impossible, but phrase the // message in Python-binding terms rather than reusing the // CLI's --allow-onnx-fallback hint (the Python API has its // own opt-in path). if !report.failed.is_empty() { return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( "partial compile failures: {} slice(s) failed to compile; catch this exception and continue to accept ONNX fallback for the failed slices, or inspect the Rust-side logs for the per-slice error payload", report.failed.len() ))); } Ok(()) } /* Python binding: run inference for input_file over the slices in slices_dir, using run_dir as the run directory; returns the run metadata as pretty-printed JSON. */ #[pyfunction] #[allow(clippy::too_many_arguments)] #[pyo3(signature = (slices_dir, input_file, run_dir, parallel=1, batch=false, weights_onnx=None, combined=true))] fn run_inference( py: Python<'_>, slices_dir: &str, input_file: &str, run_dir: &str, parallel: usize, batch: bool, weights_onnx: Option<&str>, combined: bool, ) -> PyResult { require_nonzero(parallel)?; let backend = JstproveBackend::default(); let config = RunConfig { parallel, batch, weights_onnx: weights_onnx.map(PathBuf::from), combined, }; let sd = PathBuf::from(slices_dir); let inf = PathBuf::from(input_file); let rd = PathBuf::from(run_dir); let metadata = py .allow_threads(|| pipeline::run_inference(&sd, &inf, &rd, &backend, &config)) .map_err(to_py_err)?; to_pretty_json(&metadata) } /* Python binding: pipeline::prove_run over run_dir and slices_dir with the GIL released; returns the resulting metadata as pretty-printed JSON. */ #[pyfunction] #[pyo3(signature = (run_dir, slices_dir, parallel=1))] fn prove_run(py: Python<'_>, run_dir: &str, slices_dir: &str, parallel: usize) -> PyResult { require_nonzero(parallel)?; let backend = JstproveBackend::default(); let rd = PathBuf::from(run_dir); let sd = PathBuf::from(slices_dir); let metadata = py .allow_threads(|| pipeline::prove_run(&rd, &sd, &backend, parallel)) .map_err(to_py_err)?; to_pretty_json(&metadata) } /* Python binding: pipeline::verify_run over run_dir and slices_dir with the GIL released; returns the resulting metadata as pretty-printed JSON. */ #[pyfunction]
#[pyo3(signature = (run_dir, slices_dir, parallel=1))] fn verify_run( py: Python<'_>, run_dir: &str, slices_dir: &str, parallel: usize, ) -> PyResult { require_nonzero(parallel)?; let backend = JstproveBackend::default(); let rd = PathBuf::from(run_dir); let sd = PathBuf::from(slices_dir); let metadata = py .allow_threads(|| pipeline::verify_run(&rd, &sd, &backend, parallel)) .map_err(to_py_err)?; to_pretty_json(&metadata) } /* Python entry point for the dsperse CLI. argv is first parsed as-is (clap treats its first element as the program name); if that fails it is re-parsed with "dsperse" prepended, so callers may pass argv with or without a program name. When both attempts fail, the error raised is the one from the second attempt, which can hide the more relevant first error. NOTE(review): clap reports --help and --version through its error path, so they also surface here as RuntimeError carrying the rendered text; confirm that is intended. */ #[pyfunction] #[pyo3(signature = (argv=None))] fn cli_main(py: Python<'_>, argv: Option>) -> PyResult<()> { use clap::Parser; use tracing_subscriber::EnvFilter; let cli = match argv { Some(args) => crate::cli::Cli::try_parse_from(args.clone()).or_else(|_| { let mut with_prog = vec!["dsperse".to_string()]; with_prog.extend(args); crate::cli::Cli::try_parse_from(with_prog) }), None => crate::cli::Cli::try_parse(), } .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; /* The log filter comes from the RUST_LOG environment variable when it parses, otherwise from cli.log_level. try_init fails if a global subscriber is already installed (for example on a repeated call); that error is deliberately ignored. */ let _ = tracing_subscriber::fmt() .with_env_filter( EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)), ) .try_init(); eprintln!("dsperse {}", crate::cli::VERSION); let result = py.allow_threads(|| crate::cli::dispatch(cli.command)); result.map_err(to_py_err) } /* Python binding: holographic setup for the slices in slices_dir; raises RuntimeError if any bundle in the returned report failed. */ #[pyfunction] #[pyo3(signature = (slices_dir, parallel=1, overwrite=false))] fn setup_holographic( py: Python<'_>, slices_dir: &str, parallel: usize, overwrite: bool, ) -> PyResult<()> { require_nonzero(parallel)?; let backend = JstproveBackend::default(); let dir = PathBuf::from(slices_dir); let report = py .allow_threads(|| { pipeline::setup_holographic_for_slices(&dir, &backend, parallel, overwrite) }) .map_err(to_py_err)?; if !report.failed.is_empty() { return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( "{} bundle(s) failed holographic setup", report.failed.len() ))); } Ok(()) } /* The _native extension module: registers each Python binding defined above. */ #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(slice_model, m)?)?; m.add_function(wrap_pyfunction!(compile_slices, m)?)?;
m.add_function(wrap_pyfunction!(run_inference, m)?)?; m.add_function(wrap_pyfunction!(prove_run, m)?)?; m.add_function(wrap_pyfunction!(verify_run, m)?)?; m.add_function(wrap_pyfunction!(setup_holographic, m)?)?; m.add_function(wrap_pyfunction!(cli_main, m)?)?; Ok(()) } ================================================ FILE: crates/dsperse/src/schema/execution.rs ================================================ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use super::metadata::{BackendKind, RunSliceMetadata}; /* How a slice or tile was executed. The Display strings below match the snake_case serde names. */ #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ExecutionMethod { JstproveGenWitness, OnnxOnly, Tiled, ChannelSplit, DimSplit, JstproveProve, JstproveVerify, } impl std::fmt::Display for ExecutionMethod { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::JstproveGenWitness => write!(f, "jstprove_gen_witness"), Self::OnnxOnly => write!(f, "onnx_only"), Self::Tiled => write!(f, "tiled"), Self::ChannelSplit => write!(f, "channel_split"), Self::DimSplit => write!(f, "dim_split"), Self::JstproveProve => write!(f, "jstprove_prove"), Self::JstproveVerify => write!(f, "jstprove_verify"), } } } /* Per-tile outcome. A zero time_sec and absent optional fields are omitted when serialized. */ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TileResult { pub tile_idx: usize, pub success: bool, #[serde(default, skip_serializing_if = "Option::is_none")] pub error: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub method: Option, #[serde(default, skip_serializing_if = "is_zero")] pub time_sec: f64, #[serde(default, skip_serializing_if = "Option::is_none")] pub proof_path: Option, } /* Constructors: proof_path always starts as None, and error is set only by failure. */ impl TileResult { pub fn failure( tile_idx: usize, error: String, method: Option, time_sec: f64, ) -> Self { Self { tile_idx, success: false, error: Some(error), method, time_sec, proof_path: None, } } pub fn success(tile_idx: usize, method: Option, time_sec: f64) -> Self { Self { tile_idx, success: true, error: None, method, time_sec,
proof_path: None, } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecutionInfo { pub method: ExecutionMethod, #[serde(default)] pub success: bool, #[serde(default, skip_serializing_if = "Option::is_none")] pub error: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub witness_file: Option, /* The key "tiles" is also accepted for this field when deserializing. */ #[serde(default, skip_serializing_if = "Vec::is_empty", alias = "tiles")] pub tile_exec_infos: Vec, } #[derive(Debug)] pub struct StrategyOutput { pub info: ExecutionInfo, pub outputs: Vec<(String, ndarray::ArrayD)>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SliceResult { pub slice_id: String, pub success: bool, #[serde(default, skip_serializing_if = "Option::is_none")] pub method: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub error: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub proof_path: Option, #[serde(default, skip_serializing_if = "is_zero")] pub time_sec: f64, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub tiles: Vec, } impl SliceResult { pub fn failure( slice_id: impl Into, method: ExecutionMethod, error: String, time_sec: f64, ) -> Self { Self { slice_id: slice_id.into(), success: false, method: Some(method), error: Some(error), proof_path: None, time_sec, tiles: Vec::new(), } } pub fn success(slice_id: impl Into, method: ExecutionMethod, time_sec: f64) -> Self { Self { slice_id: slice_id.into(), success: true, method: Some(method), error: None, proof_path: None, time_sec, tiles: Vec::new(), } } } /* One entry of ExecutionChain.nodes. */ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecutionNode { pub slice_id: String, #[serde(default, skip_serializing_if = "Option::is_none")] pub primary: Option, #[serde(default)] pub fallbacks: Vec, #[serde(default)] pub use_circuit: bool, #[serde(default, skip_serializing_if = "Option::is_none")] pub next: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub circuit_path: Option, #[serde(default, skip_serializing_if =
"Option::is_none")] pub onnx_path: Option, #[serde(default)] pub backend: BackendKind, } /* Produces the same values a derived Default would, since Onnx is also BackendKind's default variant. */ impl Default for ExecutionNode { fn default() -> Self { Self { slice_id: String::new(), primary: None, fallbacks: Vec::new(), use_circuit: false, next: None, circuit_path: None, onnx_path: None, backend: BackendKind::Onnx, } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecutionResultEntry { pub slice_id: String, #[serde(default, skip_serializing_if = "Option::is_none")] pub witness_execution: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub proof_execution: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub verification_execution: Option, } #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ExecutionChain { #[serde(default, skip_serializing_if = "Option::is_none")] pub head: Option, #[serde(default)] pub nodes: HashMap, #[serde(default)] pub fallback_map: HashMap>, #[serde(default)] pub execution_results: Vec, #[serde(default)] pub jstprove_proved_slices: usize, #[serde(default)] pub jstprove_verified_slices: usize, } impl ExecutionChain { /* Linear scan of execution_results; returns the first entry whose slice_id matches. */ pub fn get_result_for_slice(&self, slice_id: &str) -> Option<&ExecutionResultEntry> { self.execution_results .iter() .find(|e| e.slice_id == slice_id) } } #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct RunMetadata { #[serde(default)] pub slices: HashMap, #[serde(default)] pub execution_chain: ExecutionChain, #[serde(default, skip_serializing_if = "Option::is_none")] pub packaging_type: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub source_path: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub run_directory: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub model_path: Option, } impl RunMetadata { pub fn get_slice(&self, slice_id: &str) -> Option<&RunSliceMetadata> { self.slices.get(slice_id) } /* Yields (slice_id, metadata) for each chain node with use_circuit set that also has an entry in slices; nodes without metadata are skipped silently. Iteration follows HashMap order, so it is unspecified. */ pub fn iter_circuit_slices(&self) -> impl Iterator { self.execution_chain .nodes
.iter() .filter(|(_, node)| node.use_circuit) .filter_map(|(slice_id, _)| { self.slices .get(slice_id) .map(|meta| (slice_id.as_str(), meta)) }) } } /* serde skip_serializing_if helper: time_sec is omitted only when it is exactly 0.0. */ fn is_zero(v: &f64) -> bool { *v == 0.0 } ================================================ FILE: crates/dsperse/src/schema/metadata.rs ================================================ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use super::tiling::{ChannelSplitInfo, DimSplitInfo, TilingInfo}; /* Serialized in lowercase; "JSTPROVE" is also accepted on input. Defaults to Onnx. */ #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum BackendKind { #[serde(alias = "JSTPROVE")] Jstprove, #[default] Onnx, } impl std::fmt::Display for BackendKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Jstprove => write!(f, "jstprove"), Self::Onnx => write!(f, "onnx"), } } } #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct TensorShape { #[serde(default)] pub input: Vec>, #[serde(default)] pub output: Vec>, } #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct Dependencies { #[serde(default)] pub input: Vec, #[serde(default)] pub output: Vec, #[serde(default)] pub filtered_inputs: Vec, } #[derive(Debug, Clone, Default, Deserialize)] pub struct CompilationFiles { /* "compiled_circuit" and "circuit" are accepted as alternative keys. */ #[serde(default, alias = "compiled_circuit", alias = "circuit")] pub compiled: Option, #[serde(default)] pub settings: Option, #[serde(default)] pub pk_key: Option, #[serde(default)] pub vk_key: Option, } #[derive(Debug, Clone, Default, Deserialize)] pub struct BackendCompilation { #[serde(default)] pub compiled: bool, #[serde(default)] pub tiled: bool, #[serde(default)] pub weights_as_inputs: bool, #[serde(default)] pub files: CompilationFiles, #[serde(default)] pub compilation_timestamp: Option, } /* The jstprove section is read from existing metadata but never written back (skip_serializing). */ #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(default)] pub struct Compilation { #[serde(skip_serializing)] pub jstprove: BackendCompilation, } #[derive(Debug, Clone, Default,
Serialize, Deserialize)] pub struct SliceShapeWrapper { #[serde(default)] pub tensor_shape: TensorShape, } #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct SliceMetadata { #[serde(default)] pub index: usize, #[serde(default)] pub filename: String, #[serde(default)] pub path: String, #[serde(default)] pub relative_path: String, #[serde(default)] pub shape: SliceShapeWrapper, #[serde(default)] pub dependencies: Dependencies, #[serde(default, skip_serializing_if = "Option::is_none")] pub tiling: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub channel_split: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub dim_split: Option, #[serde(default)] pub compilation: Compilation, #[serde(default, skip_serializing_if = "Option::is_none")] pub slice_metadata: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub slice_metadata_relative_path: Option, } impl SliceMetadata { /* If more than one split is present, tiling takes precedence over channel_split, which takes precedence over dim_split. */ pub fn split_strategy(&self) -> Option> { use super::tiling::SplitStrategy; self.tiling .as_ref() .map(SplitStrategy::Tiled) .or_else(|| self.channel_split.as_ref().map(SplitStrategy::ChannelSplit)) .or_else(|| self.dim_split.as_ref().map(SplitStrategy::DimSplit)) } pub fn output_names(&self) -> &[String] { &self.dependencies.output } /* An empty relative_path resolves to model.onnx directly inside slices_dir; otherwise the path is resolved against slices_dir by utils::paths::resolve_relative_path. */ pub fn resolve_onnx( &self, slices_dir: &std::path::Path, ) -> crate::error::Result { if self.relative_path.is_empty() { Ok(slices_dir.join("model.onnx")) } else { crate::utils::paths::resolve_relative_path(slices_dir, &self.relative_path) } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RunSliceMetadata { #[serde(default)] pub path: String, #[serde(default)] pub input_shape: Vec>, #[serde(default)] pub output_shape: Vec>, #[serde(default)] pub dependencies: Dependencies, #[serde(default, skip_serializing_if = "Option::is_none")] pub tiling: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub channel_split: Option, #[serde(default, skip_serializing_if =
"Option::is_none")] pub dim_split: Option, #[serde(default)] pub backend: BackendKind, #[serde( default, skip_serializing_if = "Option::is_none", alias = "circuit_path" )] pub jstprove_circuit_path: Option, #[serde( default, skip_serializing_if = "Option::is_none", alias = "settings_path" )] pub jstprove_settings_path: Option, } /* Same precedence as SliceMetadata::split_strategy: tiling, then channel_split, then dim_split. */ impl RunSliceMetadata { pub fn split_strategy(&self) -> Option> { use super::tiling::SplitStrategy; self.tiling .as_ref() .map(SplitStrategy::Tiled) .or_else(|| self.channel_split.as_ref().map(SplitStrategy::ChannelSplit)) .or_else(|| self.dim_split.as_ref().map(SplitStrategy::DimSplit)) } } #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ModelMetadata { #[serde(default)] pub original_model: String, #[serde(default)] pub model_type: String, #[serde(default)] pub input_shape: Vec>, #[serde(default)] pub output_shapes: Vec>, #[serde(default)] pub output_names: Vec, #[serde(default)] pub slice_points: Vec, #[serde(default)] pub slices: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] pub dsperse_version: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub dsperse_rev: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub jstprove_version: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub jstprove_rev: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub traced_shapes: Option>>, #[serde(default, skip_serializing_if = "Option::is_none")] pub traced_types: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub original_model_path: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub folded_constant_names: Vec, } impl ModelMetadata { /* Decodes MessagePack via rmp_serde. NOTE(review): read_checked presumably enforces a file-size limit; confirm in utils::limits. */ pub fn load(path: &std::path::Path) -> crate::error::Result { let data = crate::utils::limits::read_checked(path)?; rmp_serde::from_slice(&data).map_err(Into::into) } /* Creates the parent directory if needed, writes named-field MessagePack to a temporary file (path with its extension replaced by msgpack.tmp), then renames that file over path. If the rename fails, the temporary file is left behind. */ pub fn save(&self, path: &std::path::Path) -> crate::error::Result<()> { if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent) .map_err(|e| crate::error::DsperseError::io(e, parent))?; } let data = rmp_serde::to_vec_named(self)?; let tmp_path = path.with_extension("msgpack.tmp"); std::fs::write(&tmp_path, &data) .map_err(|e| crate::error::DsperseError::io(e, &tmp_path))?; std::fs::rename(&tmp_path, path).map_err(|e| crate::error::DsperseError::io(e, path)) } /* Records the dsperse and jstprove versions and revisions reported by crate::version::dsperse_artifact_version(). */ pub fn stamp_version(&mut self) { let ver = crate::version::dsperse_artifact_version(); self.dsperse_version = Some(ver.dsperse_version); self.dsperse_rev = ver.dsperse_rev; self.jstprove_version = Some(ver.jstprove_version); self.jstprove_rev = ver.jstprove_rev; } } ================================================ FILE: crates/dsperse/src/schema/mod.rs ================================================ pub mod execution; pub mod metadata; pub mod tiling; pub use execution::*; pub use metadata::*; pub use tiling::*; ================================================ FILE: crates/dsperse/src/schema/tiling.rs ================================================ use serde::{self, Deserialize, Deserializer, Serialize}; #[derive(Debug, Clone)] pub enum SplitStrategy<'a> { Tiled(&'a TilingInfo), ChannelSplit(&'a ChannelSplitInfo), DimSplit(&'a DimSplitInfo), } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TileInfo { #[serde(default)] pub path: String, #[serde(default = "default_pair_zero")] pub conv_out: [i64; 2], #[serde(default, skip_serializing_if = "Option::is_none")] pub jstprove_circuit_path: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TilingInfo { #[serde(default)] pub slice_idx: usize, #[serde(default)] pub tile_size: usize, #[serde(default = "default_one")] pub num_tiles: usize, #[serde(default = "default_one")] pub tiles_y: usize, #[serde(default = "default_one")] pub tiles_x: usize, /* Accepts 2 or 4 elements on input; see deserialize_halo. */ #[serde(default = "default_quad_zero", deserialize_with = "deserialize_halo")] pub halo: [i64; 4], #[serde(default = "default_pair_zero")] pub out_tile: [i64; 2], #[serde(default =
"default_pair_one")] pub stride: [i64; 2], #[serde(default)] pub c_in: usize, #[serde(default)] pub c_out: usize, #[serde(default = "default_input_name")] pub input_name: String, #[serde(default = "default_output_name")] pub output_name: String, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub input_names: Vec, #[serde(default = "default_four")] pub ndim: usize, #[serde(default)] pub h: usize, #[serde(default)] pub w: usize, #[serde(default, skip_serializing_if = "Option::is_none")] pub tile: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub tiles: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub segment_size: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub total_elements: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub original_shape: Vec, } impl TilingInfo { /* Returns input_names when it is non-empty, otherwise just the single input_name. */ pub fn all_input_names(&self) -> Vec<&str> { if self.input_names.is_empty() { vec![&self.input_name] } else { self.input_names.iter().map(|s| s.as_str()).collect() } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChannelGroupInfo { #[serde(default)] pub group_idx: usize, #[serde(default)] pub c_start: usize, #[serde(default)] pub c_end: usize, #[serde(default)] pub path: String, #[serde(default, skip_serializing_if = "Option::is_none")] pub jstprove_circuit_path: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub jstprove_settings_path: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChannelSplitInfo { #[serde(default)] pub slice_idx: usize, #[serde(default)] pub c_in: usize, #[serde(default)] pub c_out: usize, #[serde(default = "default_one")] pub num_groups: usize, #[serde(default)] pub channels_per_group: usize, #[serde(default = "default_input_name")] pub input_name: String, #[serde(default = "default_output_name")] pub output_name: String, #[serde(default)] pub h: usize, #[serde(default)] pub w: usize, #[serde(default)] pub out_h: usize,
#[serde(default)] pub out_w: usize, #[serde(default)] pub groups: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] pub bias_path: Option, } /* Default-value helpers referenced by the serde attributes above. */ fn default_one() -> usize { 1 } fn default_four() -> usize { 4 } fn default_pair_zero() -> [i64; 2] { [0, 0] } fn default_pair_one() -> [i64; 2] { [1, 1] } fn default_quad_zero() -> [i64; 4] { [0, 0, 0, 0] } /* A 2-element halo [a, b] is expanded to [a, b, a, b]; 4 elements are used as given; any other length is a deserialization error. */ fn deserialize_halo<'de, D>(deserializer: D) -> std::result::Result<[i64; 4], D::Error> where D: Deserializer<'de>, { let v: Vec = Vec::deserialize(deserializer)?; match v.len() { 2 => Ok([v[0], v[1], v[0], v[1]]), 4 => Ok([v[0], v[1], v[2], v[3]]), _ => Err(serde::de::Error::custom(format!( "expected 2 or 4 elements for halo, got {}", v.len() ))), } } #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum DimSplitKind { #[default] MatMulOutputDim, HeadDim, BatchDim, } /* NOTE(review): the derived Default yields num_groups = 0, k_chunks = 0 and empty input/output names, whereas serde fills missing fields with 1 and "input"/"output"; confirm callers of DimSplitInfo::default() expect that difference. */ #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct DimSplitInfo { #[serde(default)] pub slice_idx: usize, #[serde(default)] pub split_kind: DimSplitKind, #[serde(default)] pub split_dim: usize, #[serde(default)] pub dim_size: usize, #[serde(default = "default_one")] pub num_groups: usize, #[serde(default)] pub elements_per_group: usize, #[serde(default = "default_input_name")] pub input_name: String, #[serde(default = "default_output_name")] pub output_name: String, #[serde(default)] pub concat_axis: usize, #[serde(default)] pub estimated_group_constraints: u64, #[serde(default, skip_serializing_if = "Option::is_none")] pub weight_name: Option, #[serde(default)] pub k_dim: usize, #[serde(default)] pub n_dim: usize, #[serde(default = "default_one")] pub k_chunks: usize, #[serde(default, skip_serializing_if = "Option::is_none")] pub template_path: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub jstprove_circuit_path: Option, } impl DimSplitInfo { /* Copies the autotiler detection into split info. estimated_group_constraints is div_ceil(k_dim, k_chunks) times n_dim times 2 when k_chunks > 1; otherwise it is the detection's total estimate integer-divided by num_groups, or the undivided total when num_groups is 0. */ pub fn from_detection( d: &crate::slicer::autotiler::DimSplitDetection, slice_idx: usize, template_path: Option, )
-> Self { let estimated_group_constraints = if d.k_chunks > 1 { (d.k_dim.div_ceil(d.k_chunks) * d.n_dim * 2) as u64 } else if d.num_groups > 0 { d.estimated_constraints / d.num_groups as u64 } else { d.estimated_constraints }; Self { slice_idx, split_kind: d.split_kind.clone(), split_dim: d.split_dim, dim_size: d.dim_size, num_groups: d.num_groups, elements_per_group: d.elements_per_group, input_name: d.input_name.clone(), output_name: d.output_name.clone(), concat_axis: d.concat_axis, estimated_group_constraints, weight_name: d.weight_name.clone(), k_dim: d.k_dim, n_dim: d.n_dim, k_chunks: d.k_chunks, template_path, jstprove_circuit_path: None, } } } fn default_input_name() -> String { "input".to_string() } fn default_output_name() -> String { "output".to_string() } ================================================ FILE: crates/dsperse/src/slicer/analyzer.rs ================================================ use std::collections::{HashMap, HashSet}; use std::path::Path; use serde::{Deserialize, Serialize}; use super::onnx_proto::{self, GraphProto, ModelProto, TensorProto}; use crate::error::{DsperseError, Result}; use crate::schema::metadata::Dependencies; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeAnalysis { pub index: usize, pub slice_name: String, pub node_type: String, pub parameter_details: HashMap, pub dependencies: NodeDependencies, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ParameterDetail { pub shape: Vec, pub size: usize, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeDependencies { pub input: Vec, pub output: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AnalysisResult { pub original_model: Option, pub model_type: String, pub node_count: usize, pub initializer_count: usize, pub input_shape: Vec>, pub output_shapes: Vec>, pub output_names: Vec, pub opset_version: Option, pub nodes: HashMap, pub initializer_names: HashSet, } /* Summarizes the model graph: one NodeAnalysis per node, keyed by node name, or "{op_type}_{index}" for unnamed nodes. Control-flow nodes additionally list outer-scope tensors referenced from their subgraphs as inputs. Returns an error if the model has no graph. */ pub fn analyze(model: &ModelProto, onnx_path:
Option<&Path>) -> Result { let graph = model .graph .as_ref() .ok_or_else(|| DsperseError::Onnx("model has no graph".into()))?; let initializer_map: HashMap<&str, &TensorProto> = graph .initializer .iter() .map(|i| (i.name.as_str(), i)) .collect(); let input_shapes = get_model_input_shapes(graph, &initializer_map); let output_shapes = get_model_output_shapes(graph); let output_names = get_model_output_names(graph); let mut nodes = HashMap::new(); for (i, node) in graph.node.iter().enumerate() { /* NOTE(review): two nodes sharing the same non-empty name map to one key, and the later insert replaces the earlier entry. */ let node_key = if node.name.is_empty() { format!("{}_{}", node.op_type, i) } else { node.name.clone() }; let parameter_details = get_parameter_details(node, &initializer_map); let mut inputs: Vec = node .input .iter() /* ONNX uses an empty name for an omitted optional input. */ .filter(|s| !s.is_empty()) .cloned() .collect(); if super::is_control_flow(&node.op_type) { let outer_refs = super::collect_subgraph_outer_refs(node, graph); for r in outer_refs { if !inputs.contains(&r) { inputs.push(r); } } } nodes.insert( node_key, NodeAnalysis { index: i, slice_name: format!("{}_{}", node.op_type, i), node_type: node.op_type.clone(), parameter_details, dependencies: NodeDependencies { input: inputs, output: node.output.clone(), }, }, ); } /* Only the default ONNX domain ("" or "ai.onnx") is considered. */ let opset_version = model .opset_import .iter() .find(|o| o.domain.is_empty() || o.domain == "ai.onnx") .map(|o| o.version); /* Older opsets only produce a warning; analysis continues. */ if let Some(v) = opset_version && v < 18 { tracing::warn!(opset = v, "opset < 18 detected; continuing anyway"); } let initializer_names: HashSet = graph.initializer.iter().map(|i| i.name.clone()).collect(); Ok(AnalysisResult { original_model: onnx_path.map(|p| p.to_string_lossy().to_string()), model_type: "ONNX".to_string(), node_count: graph.node.len(), initializer_count: graph.initializer.len(), input_shape: input_shapes, output_shapes, output_names, opset_version, nodes, initializer_names, }) } /* Shapes of the graph inputs, skipping inputs that are also initializers. */ fn get_model_input_shapes( graph: &GraphProto, initializer_map: &HashMap<&str, &TensorProto>, ) -> Vec> { graph .input .iter() .filter(|inp|
!initializer_map.contains_key(inp.name.as_str())) .map(onnx_proto::vi_shape) .collect() } fn get_model_output_shapes(graph: &GraphProto) -> Vec> { graph.output.iter().map(onnx_proto::vi_shape).collect() } fn get_model_output_names(graph: &GraphProto) -> Vec { graph.output.iter().map(|o| o.name.clone()).collect() } /* Only Conv, Gemm and MatMul nodes are inspected: each input that is an initializer with a non-zero element count is recorded with its dims. NOTE(review): if dims are signed (int64 in ONNX), a negative value would wrap in the as-usize cast. */ fn get_parameter_details( node: &onnx_proto::NodeProto, initializer_map: &HashMap<&str, &TensorProto>, ) -> HashMap { let mut details = HashMap::new(); if !matches!(node.op_type.as_str(), "Conv" | "Gemm" | "MatMul") { return details; } for inp_name in &node.input { if let Some(init) = initializer_map.get(inp_name.as_str()) { let size: usize = init.dims.iter().map(|&d| d as usize).product(); if size > 0 { details.insert( inp_name.clone(), ParameterDetail { shape: init.dims.clone(), size, }, ); } } } details } /* Dependencies of the nodes whose index lies in [start_idx, end_idx), visited in index order. input: tensors consumed but not yet produced inside the range, in first-seen order, initializers included. output: tensors produced inside the range, sorted, excluding any that also appear in input and any consumed by a node in the range after being produced (unless it is a model output). filtered_inputs: input without initializers. */ pub fn get_segment_dependencies( analysis: &AnalysisResult, start_idx: usize, end_idx: usize, ) -> Dependencies { let mut inputs = Vec::new(); let mut output_map: HashMap = HashMap::new(); let mut sorted_nodes: Vec<&NodeAnalysis> = analysis .nodes .values() .filter(|n| n.index >= start_idx && n.index < end_idx) .collect(); sorted_nodes.sort_by_key(|n| n.index); let mut consumed_in_segment: HashSet = HashSet::new(); for node in &sorted_nodes { for out in &node.dependencies.output { output_map.insert(out.clone(), true); } for inp in &node.dependencies.input { if output_map.contains_key(inp) { consumed_in_segment.insert(inp.clone()); } if !output_map.contains_key(inp) && !inputs.contains(inp) { inputs.push(inp.clone()); } } } let model_output_set: HashSet<&str> = analysis.output_names.iter().map(|s| s.as_str()).collect(); let mut outputs: Vec = output_map .keys() .filter(|output| { if inputs.contains(output) { return false; } // Exclude tensors consumed by a later node in the same segment // unless they are also model-level final outputs.
The materializer // only promotes internally-consumed tensors to graph outputs when // a downstream segment needs them; the metadata list must match. if consumed_in_segment.contains(output.as_str()) && !model_output_set.contains(output.as_str()) { return false; } true }) .cloned() .collect(); outputs.sort(); let filtered = inputs .iter() .filter(|name| !analysis.initializer_names.contains(name.as_str())) .cloned() .collect::>(); /* If every input is an initializer, keep the first input so filtered_inputs is non-empty whenever inputs is. */ let filtered_inputs = if filtered.is_empty() && !inputs.is_empty() { vec![inputs[0].clone()] } else { filtered }; Dependencies { input: inputs, output: outputs, filtered_inputs, } } #[cfg(test)] mod tests { use super::*; /* Test helpers: make_node names a node "{op}_{idx}"; the model builders wrap nodes in an opset-13 graph whose input is x and output is y, both FLOAT [1, 3, 8, 8]. */ fn make_node( op: &str, idx: usize, inputs: Vec<&str>, outputs: Vec<&str>, ) -> onnx_proto::NodeProto { onnx_proto::NodeProto { op_type: op.into(), name: format!("{}_{}", op, idx), input: inputs.into_iter().map(String::from).collect(), output: outputs.into_iter().map(String::from).collect(), attribute: vec![], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], } } fn make_model_with_nodes(nodes: Vec) -> ModelProto { let input = onnx_proto::make_tensor_value_info("x", TensorProto::FLOAT, &[1, 3, 8, 8]); let output = onnx_proto::make_tensor_value_info("y", TensorProto::FLOAT, &[1, 3, 8, 8]); let graph = onnx_proto::make_graph("test", nodes, vec![input], vec![output], vec![]); onnx_proto::make_model(graph, 13) } fn make_model_with_initializers( nodes: Vec, initializers: Vec, ) -> ModelProto { let input = onnx_proto::make_tensor_value_info("x", TensorProto::FLOAT, &[1, 3, 8, 8]); let output = onnx_proto::make_tensor_value_info("y", TensorProto::FLOAT, &[1, 3, 8, 8]); let graph = onnx_proto::make_graph("test", nodes, vec![input], vec![output], initializers); onnx_proto::make_model(graph, 13) } #[test] fn analyze_empty_model() { let model = make_model_with_nodes(vec![]); let result = analyze(&model, None).unwrap(); assert_eq!(result.node_count, 0);
assert!(result.nodes.is_empty()); assert_eq!(result.model_type, "ONNX"); } #[test] fn analyze_single_relu() { let model = make_model_with_nodes(vec![make_node("Relu", 0, vec!["x"], vec!["y"])]); let result = analyze(&model, None).unwrap(); assert_eq!(result.node_count, 1); let node = result.nodes.values().next().unwrap(); assert_eq!(node.node_type, "Relu"); assert!(node.parameter_details.is_empty()); } #[test] fn analyze_conv_with_initializer() { let weight_data: Vec = vec![1.0; 27]; let weight_tensor = onnx_proto::make_tensor( "conv_weight", TensorProto::FLOAT, &[1, 3, 3, 3], weight_data, ); let conv = make_node("Conv", 0, vec!["x", "conv_weight"], vec!["y"]); let model = make_model_with_initializers(vec![conv], vec![weight_tensor]); let result = analyze(&model, None).unwrap(); assert_eq!(result.initializer_count, 1); let node = result.nodes.values().next().unwrap(); assert!(!node.parameter_details.is_empty()); let detail = node.parameter_details.get("conv_weight").unwrap(); assert_eq!(detail.shape, vec![1, 3, 3, 3]); assert_eq!(detail.size, 27); } #[test] fn analyze_non_param_op_has_no_details() { let weight_data: Vec = vec![1.0; 27]; let weight_tensor = onnx_proto::make_tensor("add_weight", TensorProto::FLOAT, &[1, 3, 3, 3], weight_data); let add = make_node("Add", 0, vec!["x", "add_weight"], vec!["y"]); let model = make_model_with_initializers(vec![add], vec![weight_tensor]); let result = analyze(&model, None).unwrap(); let node = result.nodes.values().next().unwrap(); assert!(node.parameter_details.is_empty()); } #[test] fn analyze_model_no_graph() { let model = ModelProto { graph: None, ..Default::default() }; assert!(analyze(&model, None).is_err()); } #[test] fn analyze_dependencies_tracked() { let conv = make_node("Conv", 0, vec!["x", "w"], vec!["conv_out"]); let relu = make_node("Relu", 1, vec!["conv_out"], vec!["y"]); let model = make_model_with_nodes(vec![conv, relu]); let result = analyze(&model, None).unwrap(); assert_eq!(result.node_count, 2); let
relu_node = result .nodes .values() .find(|n| n.node_type == "Relu") .unwrap(); assert_eq!(relu_node.dependencies.input, vec!["conv_out"]); assert_eq!(relu_node.dependencies.output, vec!["y"]); } #[test] fn analyze_unnamed_nodes_get_generated_keys() { let mut node = make_node("Relu", 0, vec!["x"], vec!["y"]); node.name = String::new(); let model = make_model_with_nodes(vec![node]); let result = analyze(&model, None).unwrap(); assert!(result.nodes.contains_key("Relu_0")); } #[test] fn get_segment_dependencies_basic() { let mut nodes = HashMap::new(); nodes.insert( "conv".into(), NodeAnalysis { index: 0, slice_name: "Conv_0".into(), node_type: "Conv".into(), parameter_details: HashMap::new(), dependencies: NodeDependencies { input: vec!["x".into(), "w".into()], output: vec!["conv_out".into()], }, }, ); nodes.insert( "relu".into(), NodeAnalysis { index: 1, slice_name: "Relu_1".into(), node_type: "Relu".into(), parameter_details: HashMap::new(), dependencies: NodeDependencies { input: vec!["conv_out".into()], output: vec!["relu_out".into()], }, }, ); let analysis = AnalysisResult { original_model: None, model_type: "ONNX".into(), node_count: 2, initializer_count: 1, input_shape: vec![], output_shapes: vec![], output_names: vec![], opset_version: Some(13), nodes, initializer_names: HashSet::from(["w".into()]), }; let deps = get_segment_dependencies(&analysis, 0, 2); assert!(deps.output.contains(&"relu_out".to_string())); assert!(!deps.filtered_inputs.contains(&"w".to_string())); } /* Wraps a GraphProto as a graph-typed attribute, as used for a Loop body or an If branch. */ fn make_attribute_graph( name: &str, graph: onnx_proto::GraphProto, ) -> onnx_proto::AttributeProto { onnx_proto::AttributeProto { name: name.to_string(), r#type: onnx_proto::onnx::attribute_proto::AttributeType::Graph as i32, g: Some(graph), ..Default::default() } } #[test] fn analyze_loop_captures_outer_scope_refs() { let relu = make_node("Relu", 0, vec!["x"], vec!["relu_out"]); let body_node = onnx_proto::NodeProto { op_type: "Add".into(), name: "body_add".into(), input:
vec!["body_in".into(), "relu_out".into()], output: vec!["body_out".into()], attribute: vec![], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let body_input = onnx_proto::make_tensor_value_info("body_in", TensorProto::FLOAT, &[1, 3, 8, 8]); let body_cond_in = onnx_proto::make_tensor_value_info("cond_in", TensorProto::BOOL, &[]); let body_cond_out = onnx_proto::make_tensor_value_info("cond_out", TensorProto::BOOL, &[]); let body_output = onnx_proto::make_tensor_value_info("body_out", TensorProto::FLOAT, &[1, 3, 8, 8]); /* NOTE(review): this body omits the iteration-number input that the ONNX Loop spec places first; harmless here, since the test only checks outer-scope references. */ let body_graph = onnx_proto::make_graph( "loop_body", vec![body_node], vec![body_cond_in.clone(), body_input], vec![body_cond_out, body_output], vec![], ); let loop_node = onnx_proto::NodeProto { op_type: "Loop".into(), name: "Loop_1".into(), input: vec!["trip_count".into(), "cond".into(), "init_val".into()], output: vec!["loop_out".into()], attribute: vec![make_attribute_graph("body", body_graph)], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let input = onnx_proto::make_tensor_value_info("x", TensorProto::FLOAT, &[1, 3, 8, 8]); let output = onnx_proto::make_tensor_value_info("loop_out", TensorProto::FLOAT, &[1, 3, 8, 8]); let trip_vi = onnx_proto::make_tensor_value_info("trip_count", TensorProto::INT64, &[]); let cond_vi = onnx_proto::make_tensor_value_info("cond", TensorProto::BOOL, &[]); let init_vi = onnx_proto::make_tensor_value_info("init_val", TensorProto::FLOAT, &[1, 3, 8, 8]); let graph = onnx_proto::make_graph( "test", vec![relu, loop_node], vec![input, trip_vi, cond_vi, init_vi], vec![output], vec![], ); let model = onnx_proto::make_model(graph, 13); let result = analyze(&model, None).unwrap(); let loop_analysis = result .nodes .values() .find(|n| n.node_type == "Loop") .unwrap(); let loop_inputs = &loop_analysis.dependencies.input; assert!(
loop_inputs.contains(&"relu_out".to_string()), "Loop node must include outer-scope ref 'relu_out' in its dependencies, got: {:?}", loop_inputs ); for local in &["body_in", "body_out", "cond_in", "cond_out"] { assert!( !loop_inputs.contains(&local.to_string()), "body-local name '{}' must not leak into Loop dependencies, got: {:?}", local, loop_inputs ); } } #[test] fn analyze_if_captures_outer_scope_refs() { let relu = make_node("Relu", 0, vec!["x"], vec!["relu_out"]); let then_node = onnx_proto::NodeProto { op_type: "Identity".into(), name: "then_id".into(), input: vec!["relu_out".into()], output: vec!["then_out".into()], attribute: vec![], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let then_output = onnx_proto::make_tensor_value_info("then_out", TensorProto::FLOAT, &[1, 3, 8, 8]); let then_graph = onnx_proto::make_graph( "then_branch", vec![then_node], vec![], vec![then_output], vec![], ); let else_node = onnx_proto::NodeProto { op_type: "Neg".into(), name: "else_neg".into(), input: vec!["relu_out".into()], output: vec!["else_out".into()], attribute: vec![], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let else_output = onnx_proto::make_tensor_value_info("else_out", TensorProto::FLOAT, &[1, 3, 8, 8]); let else_graph = onnx_proto::make_graph( "else_branch", vec![else_node], vec![], vec![else_output], vec![], ); let if_node = onnx_proto::NodeProto { op_type: "If".into(), name: "If_1".into(), input: vec!["cond".into()], output: vec!["if_out".into()], attribute: vec![ make_attribute_graph("then_branch", then_graph), make_attribute_graph("else_branch", else_graph), ], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let input = onnx_proto::make_tensor_value_info("x", TensorProto::FLOAT, &[1, 3, 8, 8]); let
cond_vi = onnx_proto::make_tensor_value_info("cond", TensorProto::BOOL, &[]); let output = onnx_proto::make_tensor_value_info("if_out", TensorProto::FLOAT, &[1, 3, 8, 8]); let graph = onnx_proto::make_graph( "test", vec![relu, if_node], vec![input, cond_vi], vec![output], vec![], ); let model = onnx_proto::make_model(graph, 13); let result = analyze(&model, None).unwrap(); let if_analysis = result.nodes.values().find(|n| n.node_type == "If").unwrap(); let if_inputs = &if_analysis.dependencies.input; assert!( if_inputs.contains(&"relu_out".to_string()), "If node must include outer-scope ref 'relu_out' from both branches, got: {:?}", if_inputs ); for local in &["then_out", "else_out"] { assert!( !if_inputs.contains(&local.to_string()), "branch-local name '{}' must not leak into If dependencies, got: {:?}", local, if_inputs ); } } #[test] fn segment_deps_include_subgraph_outer_refs() { let relu = make_node("Relu", 0, vec!["x"], vec!["relu_out"]); let body_node = onnx_proto::NodeProto { op_type: "Add".into(), name: "body_add".into(), input: vec!["body_in".into(), "relu_out".into()], output: vec!["body_out".into()], attribute: vec![], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let body_input = onnx_proto::make_tensor_value_info("body_in", TensorProto::FLOAT, &[1, 3, 8, 8]); let body_cond_in = onnx_proto::make_tensor_value_info("cond_in", TensorProto::BOOL, &[]); let body_cond_out = onnx_proto::make_tensor_value_info("cond_out", TensorProto::BOOL, &[]); let body_output = onnx_proto::make_tensor_value_info("body_out", TensorProto::FLOAT, &[1, 3, 8, 8]); let body_graph = onnx_proto::make_graph( "loop_body", vec![body_node], vec![body_cond_in, body_input], vec![body_cond_out, body_output], vec![], ); let loop_node = onnx_proto::NodeProto { op_type: "Loop".into(), name: "Loop_1".into(), input: vec!["trip_count".into(), "cond".into(), "init_val".into()], output: vec!["loop_out".into()],
attribute: vec![make_attribute_graph("body", body_graph)], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let input = onnx_proto::make_tensor_value_info("x", TensorProto::FLOAT, &[1, 3, 8, 8]); let output = onnx_proto::make_tensor_value_info("loop_out", TensorProto::FLOAT, &[1, 3, 8, 8]); let trip_vi = onnx_proto::make_tensor_value_info("trip_count", TensorProto::INT64, &[]); let cond_vi = onnx_proto::make_tensor_value_info("cond", TensorProto::BOOL, &[]); let init_vi = onnx_proto::make_tensor_value_info("init_val", TensorProto::FLOAT, &[1, 3, 8, 8]); let graph = onnx_proto::make_graph( "test", vec![relu, loop_node], vec![input, trip_vi, cond_vi, init_vi], vec![output], vec![], ); let model = onnx_proto::make_model(graph, 13); let result = analyze(&model, None).unwrap(); let deps = get_segment_dependencies(&result, 1, 2); assert!( deps.input.contains(&"relu_out".to_string()), "segment containing only Loop must list 'relu_out' as input dep, got: {:?}", deps.input ); for local in &["body_in", "body_out", "cond_in", "cond_out"] { assert!( !deps.input.contains(&local.to_string()), "body-local name '{}' must not appear in segment inputs, got: {:?}", local, deps.input ); } } #[test] fn analyze_nested_subgraph_captures_outer_scope_refs() { let relu = make_node("Relu", 0, vec!["x"], vec!["relu_out"]); let inner_add = onnx_proto::NodeProto { op_type: "Add".into(), name: "inner_add".into(), input: vec!["inner_in".into(), "relu_out".into()], output: vec!["inner_out".into()], attribute: vec![], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let inner_input = onnx_proto::make_tensor_value_info("inner_in", TensorProto::FLOAT, &[1, 3, 8, 8]); let inner_output = onnx_proto::make_tensor_value_info("inner_out", TensorProto::FLOAT, &[1, 3, 8, 8]); let inner_graph = onnx_proto::make_graph( "inner_then", 
vec![inner_add], vec![inner_input], vec![inner_output], vec![], ); let if_node_in_body = onnx_proto::NodeProto { op_type: "If".into(), name: "nested_if".into(), input: vec!["body_cond".into()], output: vec!["body_out".into()], attribute: vec![ make_attribute_graph("then_branch", inner_graph.clone()), make_attribute_graph("else_branch", inner_graph), ], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let body_cond_in = onnx_proto::make_tensor_value_info("cond_in", TensorProto::BOOL, &[]); let body_cond = onnx_proto::make_tensor_value_info("body_cond", TensorProto::BOOL, &[]); let body_cond_out = onnx_proto::make_tensor_value_info("cond_out", TensorProto::BOOL, &[]); let body_output = onnx_proto::make_tensor_value_info("body_out", TensorProto::FLOAT, &[1, 3, 8, 8]); let body_graph = onnx_proto::make_graph( "loop_body", vec![if_node_in_body], vec![body_cond_in, body_cond], vec![body_cond_out, body_output], vec![], ); let loop_node = onnx_proto::NodeProto { op_type: "Loop".into(), name: "Loop_1".into(), input: vec!["trip_count".into(), "cond".into(), "init_val".into()], output: vec!["loop_out".into()], attribute: vec![make_attribute_graph("body", body_graph)], domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], }; let input = onnx_proto::make_tensor_value_info("x", TensorProto::FLOAT, &[1, 3, 8, 8]); let output = onnx_proto::make_tensor_value_info("loop_out", TensorProto::FLOAT, &[1, 3, 8, 8]); let trip_vi = onnx_proto::make_tensor_value_info("trip_count", TensorProto::INT64, &[]); let cond_vi = onnx_proto::make_tensor_value_info("cond", TensorProto::BOOL, &[]); let init_vi = onnx_proto::make_tensor_value_info("init_val", TensorProto::FLOAT, &[1, 3, 8, 8]); let graph = onnx_proto::make_graph( "test", vec![relu, loop_node], vec![input, trip_vi, cond_vi, init_vi], vec![output], vec![], ); let model = 
onnx_proto::make_model(graph, 13);
        let result = analyze(&model, None).unwrap();
        let loop_analysis = result
            .nodes
            .values()
            .find(|n| n.node_type == "Loop")
            .unwrap();
        let nested_inputs = &loop_analysis.dependencies.input;
        assert!(
            nested_inputs.contains(&"relu_out".to_string()),
            "Loop with nested If subgraph referencing outer-scope 'relu_out' must capture it, got: {:?}",
            nested_inputs
        );
        for local in &["body_cond", "inner_in", "inner_out", "body_out"] {
            assert!(
                !nested_inputs.contains(&local.to_string()),
                "nested-body-local name '{}' must not leak into Loop dependencies, got: {:?}",
                local,
                nested_inputs
            );
        }
    }
}
================================================
FILE: crates/dsperse/src/slicer/autotiler.rs
================================================
use std::collections::{HashMap, HashSet};
use std::path::Path;

use super::onnx_proto::{self, GraphProto, ModelProto, NodeProto, TensorProto};
use crate::error::Result;
use crate::schema::tiling::{ChannelGroupInfo, ChannelSplitInfo, DimSplitKind};

/// Convert a slice into a `[i64; 2]` when it has exactly two elements;
/// any other length yields `None`.
fn try_pair(v: &[i64]) -> Option<[i64; 2]> {
    if v.len() == 2 { Some([v[0], v[1]]) } else { None }
}

/// Convert a slice into a `[i64; 4]` when it has exactly four elements;
/// any other length yields `None`.
fn try_quad(v: &[i64]) -> Option<[i64; 4]> {
    if v.len() == 4 { Some([v[0], v[1], v[2], v[3]]) } else { None }
}

/// Highest opset version imported for the default ONNX domain (the entry
/// whose `domain` string is empty). Falls back to 13 when the model
/// declares no default-domain opset.
pub(crate) fn model_opset(model: &ModelProto) -> i64 {
    model
        .opset_import
        .iter()
        .filter(|o| o.domain.is_empty())
        .map(|o| o.version)
        .max()
        .unwrap_or(13)
}

/// Local alias for the parent module's elementwise-op classifier.
fn is_elementwise(op: &str) -> bool {
    super::is_elementwise(op)
}

/// Public description of a channel split for one slice.
///
/// NOTE(review): not constructed in the visible part of this file. The field
/// meanings presumably match the same-named fields of
/// `TilingDetection::ChannelSplit` built in `detect_tiling_needs`, and
/// `slice_idx` presumably identifies the slice. Confirm with the users of
/// this type.
#[derive(Debug, Clone)]
pub struct ChannelSplitParams {
    pub c_in: i64,
    pub c_out: i64,
    pub num_groups: i64,
    pub channels_per_group: i64,
    pub h: i64,
    pub w: i64,
    pub slice_idx: usize,
}

/// Geometry of a 2-D `MaxPool` node accepted by the spatial tiler.
struct PoolParams {
    // Index of the MaxPool node within `graph.node`.
    node_idx: usize,
    // [kernel_h, kernel_w]
    kernel: [i64; 2],
    // [stride_h, stride_w]; defaults to [1, 1].
    stride: [i64; 2],
    // [dilation_h, dilation_w]; defaults to [1, 1].
    dilation: [i64; 2],
    // ONNX `pads` order: [h_begin, w_begin, h_end, w_end]; defaults to zeros.
    pads: [i64; 4],
}

impl PoolParams {
    /// Parse `node` as a tileable 2-D `MaxPool`.
    ///
    /// Returns `None` in any of these cases:
    /// - the op is not `MaxPool`, or `kernel_shape` is missing;
    /// - `kernel_shape` or `strides`/`dilations` (when present) are not 2-element lists;
    /// - `pads` (when present) is not a 4-element list;
    /// - `auto_pad` is set to anything other than empty or `NOTSET`;
    /// - `ceil_mode` is non-zero;
    /// - any kernel, stride or dilation value is non-positive, or any pad is negative.
    ///
    /// NOTE(review): the `storage_order` attribute and the optional second
    /// (`Indices`) output are not inspected here. Confirm that tiled
    /// execution handles or excludes them.
    fn from_node(node: &NodeProto, node_idx: usize) -> Option {
        if node.op_type != "MaxPool" {
            return None;
        }
        let kernel = try_pair(&onnx_proto::get_attribute_ints(node, "kernel_shape")?)?;
        let stride = match onnx_proto::get_attribute_ints(node, "strides") {
            None => [1, 1],
            Some(v) => try_pair(&v)?,
        };
        let dilation = match onnx_proto::get_attribute_ints(node, "dilations") {
            None => [1, 1],
            Some(v) => try_pair(&v)?,
        };
        let auto_pad = node
            .attribute
            .iter()
            .find(|a| a.name == "auto_pad")
            .map(|a| a.s.as_slice());
        // Only explicit padding is modelled; SAME_*/VALID auto-padding is rejected.
        if matches!(auto_pad, Some(v) if !v.is_empty() && v != b"NOTSET") {
            return None;
        }
        let pads = match onnx_proto::get_attribute_ints(node, "pads") {
            None => [0, 0, 0, 0],
            Some(v) => try_quad(&v)?,
        };
        let ceil_mode = onnx_proto::get_attribute_int(node, "ceil_mode").unwrap_or(0);
        if ceil_mode != 0 {
            return None;
        }
        if kernel.iter().any(|&v| v <= 0) || stride.iter().any(|&v| v <= 0) {
            return None;
        }
        if dilation.iter().any(|&v| v <= 0) || pads.iter().any(|&v| v < 0) {
            return None;
        }
        Some(PoolParams {
            node_idx,
            kernel,
            stride,
            dilation,
            pads,
        })
    }
}

/// Parameters of the first node in `graph` that `PoolParams::from_node`
/// accepts, or `None` if no node qualifies.
fn get_pool_params(graph: &GraphProto) -> Option {
    for (idx, node) in graph.node.iter().enumerate() {
        if let Some(pp) = PoolParams::from_node(node, idx) {
            return Some(pp);
        }
    }
    None
}

/// Geometry and channel counts of a 2-D `Conv` node accepted by the tiler.
struct ConvParams {
    // Index of the Conv node within `graph.node`.
    node_idx: usize,
    // [kernel_h, kernel_w]; must equal the weight's spatial dims.
    kernel: [i64; 2],
    // [stride_h, stride_w]; defaults to [1, 1].
    stride: [i64; 2],
    // [dilation_h, dilation_w]; defaults to [1, 1].
    dilation: [i64; 2],
    // ONNX `pads` order: [h_begin, w_begin, h_end, w_end]; defaults to zeros.
    pads: [i64; 4],
    // ONNX `group` attribute; defaults to 1.
    group: i64,
    // Weight dims[0]: number of output channels.
    c_out: i64,
    // Weight dims[1]. Per the ONNX Conv spec this is C / group, so it equals
    // the true input-channel count only when `group == 1`.
    c_in: i64,
}

impl ConvParams {
    /// Parse `node` as a tileable 2-D `Conv` whose weight (`input[1]`) is a
    /// rank-4 initializer of `graph` with positive `dims[0]` and `dims[1]`.
    ///
    /// An explicit `kernel_shape` must match the weight's spatial dims.
    /// `strides`/`dilations` must be 2-element lists and `pads` a 4-element
    /// list when present. `auto_pad` must be empty or `NOTSET`.
    /// Kernel, stride, dilation and `group` must be positive, and pads
    /// non-negative. Any violation yields `None`.
    fn from_node(node: &NodeProto, node_idx: usize, graph: &GraphProto) -> Option {
        if node.op_type != "Conv" {
            return None;
        }
        let w_name = node.input.get(1)?;
        // Runtime (non-initializer) weights are not supported.
        let w = graph.initializer.iter().find(|t| &t.name == w_name)?;
        if w.dims.len() != 4 {
            return None;
        }
        let c_out = w.dims[0];
        let c_in = w.dims[1];
        if c_out <= 0 || c_in <= 0 {
            return None;
        }
        let inferred_kernel = [w.dims[2], w.dims[3]];
        let kernel = match onnx_proto::get_attribute_ints(node, "kernel_shape") {
            Some(v) => {
                let k = try_pair(&v)?;
                // Reject models whose declared kernel disagrees with the weight.
                if k != inferred_kernel {
                    return None;
                }
                k
            }
            None => inferred_kernel,
        };
        let stride = match onnx_proto::get_attribute_ints(node, "strides") {
            None => [1, 1],
            Some(v) => try_pair(&v)?,
        };
        let dilation = match onnx_proto::get_attribute_ints(node, "dilations") {
            None => [1, 1],
            Some(v) => try_pair(&v)?,
        };
        let auto_pad = node
            .attribute
            .iter()
            .find(|a| a.name == "auto_pad")
            .map(|a| a.s.as_slice());
        // Only explicit padding is modelled; SAME_*/VALID auto-padding is rejected.
        if matches!(auto_pad, Some(v) if !v.is_empty() && v != b"NOTSET") {
            return None;
        }
        let pads = match onnx_proto::get_attribute_ints(node, "pads") {
            None => [0, 0, 0, 0],
            Some(v) => try_quad(&v)?,
        };
        if kernel.iter().any(|&v| v <= 0) {
            return None;
        }
        if stride.iter().any(|&v| v <= 0) {
            return None;
        }
        if dilation.iter().any(|&v| v <= 0) {
            return None;
        }
        if pads.iter().any(|&v| v < 0) {
            return None;
        }
        let group = onnx_proto::get_attribute_int(node, "group").unwrap_or(1);
        if group <= 0 {
            return None;
        }
        Some(ConvParams {
            node_idx,
            kernel,
            stride,
            dilation,
            pads,
            group,
            c_out,
            c_in,
        })
    }
}

/// Parameters of the first node in `graph` that `ConvParams::from_node`
/// accepts, or `None` if no node qualifies.
fn get_conv_params(graph: &GraphProto) -> Option {
    for (idx, node) in graph.node.iter().enumerate() {
        if let Some(cp) = ConvParams::from_node(node, idx, graph) {
            return Some(cp);
        }
    }
    None
}

/// Dilated kernel extent per axis, `(k - 1) * d + 1`, computed with
/// overflow-checked arithmetic (`None` on overflow).
fn effective_kernel(kernel: [i64; 2], dilation: [i64; 2]) -> Option<[i64; 2]> {
    let ek0 = kernel[0]
        .checked_sub(1)?
        .checked_mul(dilation[0])?
        .checked_add(1)?;
    let ek1 = kernel[1]
        .checked_sub(1)?
        .checked_mul(dilation[1])?
        .checked_add(1)?;
    Some([ek0, ek1])
}

/// Spatial output size of a sliding window, computed per axis as
/// `floor((in + pad_begin + pad_end - eff_kernel) / stride) + 1`.
///
/// Returns `None` on a non-positive stride, on arithmetic overflow, or when
/// either output extent would be non-positive.
fn conv_output_hw(
    h_in: i64,
    w_in: i64,
    pads: [i64; 4],
    kernel: [i64; 2],
    dilation: [i64; 2],
    stride: [i64; 2],
) -> Option<(i64, i64)> {
    if stride[0] <= 0 || stride[1] <= 0 {
        return None;
    }
    let eff = effective_kernel(kernel, dilation)?;
    let num_h = h_in
        .checked_add(pads[0])?
        .checked_add(pads[2])?
        .checked_sub(eff[0])?;
    let num_w = w_in
        .checked_add(pads[1])?
        .checked_add(pads[3])?
        .checked_sub(eff[1])?;
    // div_euclid with a positive divisor is floor division, so a negative
    // numerator yields a non-positive output size (rejected below).
    let out_h = num_h.div_euclid(stride[0]).checked_add(1)?;
    let out_w = num_w.div_euclid(stride[1]).checked_add(1)?;
    if out_h <= 0 || out_w <= 0 {
        return None;
    }
    Some((out_h, out_w))
}

/// Halo reported for spatial tiling. It is simply the node's explicit pads,
/// in the same `[h_begin, w_begin, h_end, w_end]` order. Returns `None` if
/// any pad is negative.
fn compute_halo_size(pads: [i64; 4]) -> Option<[i64; 4]> {
    if pads.iter().any(|&v| v < 0) {
        return None;
    }
    Some(pads)
}

/// Smallest tile edge the tiler will consider: the larger effective
/// (dilated) kernel extent plus one. Returns `None` on overflow.
fn compute_min_spatial_tile(kernel: [i64; 2], dilation: [i64; 2]) -> Option {
    let eff = effective_kernel(kernel, dilation)?;
    eff[0].max(eff[1]).checked_add(1)
}

/// Geometry of the single spatial op (Conv or MaxPool) in a slice.
/// It has the same layout and defaults as the `pads`/`stride`/`dilation`
/// fields of `PoolParams`/`ConvParams`.
struct SpatialKernelParams {
    kernel: [i64; 2],
    stride: [i64; 2],
    dilation: [i64; 2],
    pads: [i64; 4],
}

/// Structural gate shared by spatial tiling and channel splitting.
///
/// Returns the geometry of the slice's `primary_op` node (`"Conv"` or
/// `"MaxPool"`) only when all of the following hold:
/// - the graph has at most one input;
/// - exactly one node has op type `primary_op`, it is accepted by
///   `ConvParams`/`PoolParams`, and it is the first node of the graph;
/// - every other op is elementwise.
///
/// Any other `primary_op` yields `None`.
fn extract_spatial_kernel_params(
    graph: &GraphProto,
    primary_op: &str,
) -> Option {
    if graph.input.len() > 1 {
        return None;
    }
    let op_count = graph
        .node
        .iter()
        .filter(|n| n.op_type == primary_op)
        .count();
    if op_count != 1 {
        return None;
    }
    let (node_idx, kernel, stride, dilation, pads) = if primary_op == "Conv" {
        let cp = get_conv_params(graph)?;
        (cp.node_idx, cp.kernel, cp.stride, cp.dilation, cp.pads)
    } else if primary_op == "MaxPool" {
        let pp = get_pool_params(graph)?;
        (pp.node_idx, pp.kernel, pp.stride, pp.dilation, pp.pads)
    } else {
        return None;
    };
    // The spatial op must consume the slice input directly.
    if node_idx != 0 {
        return None;
    }
    let ops: HashSet<&str> = graph.node.iter().map(|n| n.op_type.as_str()).collect();
    if ops.iter().any(|&o| o != primary_op && !is_elementwise(o)) {
        return None;
    }
    Some(SpatialKernelParams {
        kernel,
        stride,
        dilation,
        pads,
    })
}

/// Whether a slice built around `primary_op` can be tiled spatially.
///
/// On top of the structural gate, the total padding on each axis must be
/// at least `eff_kernel - stride`. Under that condition the output extent
/// is at least `in / stride` when `in` is a multiple of the stride.
/// Presumably this is what lets each input tile map onto `tile / stride`
/// outputs (see `out_tile` in `detect_tiling_needs`).
fn is_spatial_tileable(graph: &GraphProto, primary_op: &str) -> bool {
    let Some(sp) = extract_spatial_kernel_params(graph, primary_op) else {
        return false;
    };
    let Some(eff) = effective_kernel(sp.kernel, sp.dilation) else {
        return false;
    };
    let total_pad_h = sp.pads[0] + sp.pads[2];
    let total_pad_w = sp.pads[1] + sp.pads[3];
    total_pad_h >= eff[0] - sp.stride[0] && total_pad_w >= eff[1] - sp.stride[1]
}

/// `ConvParams` of the slice's Conv, provided the slice passes the
/// structural gate of `extract_spatial_kernel_params` for `"Conv"`.
fn is_standard_conv_slice(graph: &GraphProto) -> Option {
    extract_spatial_kernel_params(graph, "Conv")?;
    get_conv_params(graph)
}

/// Whether a Conv slice can be tiled spatially.
fn is_tileable(graph: &GraphProto) -> bool {
    is_spatial_tileable(graph, "Conv")
}

/// Whether a Conv slice can be split along input channels. This requires a
/// standard Conv slice with `group == 1`; grouped and depthwise convolutions
/// are excluded.
fn is_channel_splittable(graph: &GraphProto) -> bool {
    let Some(cp) = is_standard_conv_slice(graph) else {
        return false;
    };
    cp.group == 1
}

/// Names of the first graph input and first graph output, plus `(C, H, W)`
/// of the first input.
///
/// The first input must be rank 4 with positive C, H and W. The batch dim
/// (index 0) is not validated.
fn get_model_dimensions(graph: &GraphProto) -> Option<(String, String, i64, i64, i64)> {
    let inp = graph.input.first()?;
    let out = graph.output.first()?;
    let dims = onnx_proto::vi_shape(inp);
    if dims.len() != 4 || dims[1] <= 0 || dims[2] <= 0 || dims[3] <= 0 {
        return None;
    }
    Some((
        inp.name.clone(),
        out.name.clone(),
        dims[1],
        dims[2],
        dims[3],
    ))
}

/// True when the graph has at least one node and one input, and every node
/// is elementwise.
fn is_elementwise_only_slice(graph: &GraphProto) -> bool {
    if graph.node.is_empty() || graph.input.is_empty() {
        return false;
    }
    graph.node.iter().all(|n| is_elementwise(&n.op_type))
}

/// Decode the Conv weight (`input[1]`) and optional bias (`input[2]`) from
/// the graph's initializers via `onnx_proto::tensor_to_f32`.
///
/// NOTE(review): this keeps the *last* initializer with a matching name,
/// while `ConvParams::from_node` uses the *first*. That only matters if the
/// graph contains duplicate initializer names.
fn find_weights_and_bias(
    graph: &GraphProto,
    conv_node: &NodeProto,
) -> (Option, Option>) {
    let mut weights: Option = None;
    let mut bias: Option> = None;
    for init in &graph.initializer {
        if conv_node.input.len() > 1 && init.name == conv_node.input[1] {
            let data = onnx_proto::tensor_to_f32(init);
            weights = Some(WeightInfo {
                data,
                dims: init.dims.clone(),
            });
        }
        if conv_node.input.len() > 2 && init.name == conv_node.input[2] {
            bias = Some(onnx_proto::tensor_to_f32(init));
        }
    }
    (weights, bias)
}

/// Decoded Conv weight: flat values from `onnx_proto::tensor_to_f32` plus
/// the initializer's dims.
struct WeightInfo {
    data: Vec,
    dims: Vec,
}

/// Conv parameters and decoded weight/bias of a slice, borrowing the
/// slice's graph.
struct SlicePrologue<'a> {
    graph: &'a GraphProto,
    cp: ConvParams,
    weights: Option,
    bias: Option>,
}

/// Collect the first accepted Conv's parameters together with its decoded
/// weight and bias.
///
/// Returns `None` if there is no graph or no accepted Conv. When a weight
/// is found, `None` is also returned if any of these hold:
/// - the weight is not rank 4, or a dim is negative;
/// - its element count differs from the product of its dims (also on overflow);
/// - a bias is present whose length differs from the output-channel count.
fn extract_slice_prologue(model: &ModelProto) -> Option> {
    let graph = model.graph.as_ref()?;
    let cp = get_conv_params(graph)?;
    let conv_node = &graph.node[cp.node_idx];
    let (weights, bias) = find_weights_and_bias(graph, conv_node);
    if let Some(ref w) = weights {
        if w.dims.len() != 4 {
            return None;
        }
        let c_out = usize::try_from(w.dims[0]).ok()?;
        let c_in = usize::try_from(w.dims[1]).ok()?;
        let kh = usize::try_from(w.dims[2]).ok()?;
        let kw = usize::try_from(w.dims[3]).ok()?;
        let expected = c_out.checked_mul(c_in)?.checked_mul(kh)?.checked_mul(kw)?;
        if w.data.len() != expected {
            return None;
        }
        if let Some(ref b) = bias
            && b.len() != c_out
        {
            return None;
        }
    }
    Some(SlicePrologue {
        graph,
        cp,
        weights,
        bias,
    })
}

/// Largest tile in `[min_tile, target]` that divides `spatial_dim` and is a
/// multiple of `stride`.
///
/// Requires `target < spatial_dim`, so any tile found yields at least two
/// tiles along this axis. `stride` must be non-zero, otherwise
/// `tile % stride` panics. The callers in this file pass strides already
/// validated as positive.
fn find_optimal_tile_size(
    spatial_dim: i64,
    target: i64,
    min_tile: i64,
    stride: i64,
) -> Option {
    if min_tile <= target && target < spatial_dim {
        for tile in (min_tile..=target).rev() {
            if spatial_dim % tile == 0 && tile % stride == 0 {
                return Some(tile);
            }
        }
    }
    None
}

/// Choose a square spatial tile edge for a `channels x h x w` input, so that
/// `channels * tile^2` stays within `tile_size` elements (up to float
/// rounding in the square root).
///
/// Returns `(Some(tile), None)` on success. Otherwise it returns
/// `(None, Some(reason))`, where `reason` is one of:
/// - `"already_fits"`: the whole input is within the budget;
/// - `"min_tile_too_large"`: even the minimum tile exceeds the budget;
/// - `"no_divisor"`: no suitable divisor of `h` was found.
///
/// Only `h` is checked for divisibility, and only the supplied (H-axis)
/// `stride` is applied. `detect_tiling_needs` re-checks `w` and the W stride.
///
/// NOTE(review): `channels * h * w` is unchecked multiplication. An overflow
/// would panic in debug builds and wrap in release builds.
fn calculate_spatial_tile_config(
    channels: i64,
    h: i64,
    w: i64,
    tile_size: i64,
    min_tile: i64,
    stride: i64,
) -> (Option, Option<&'static str>) {
    let total = channels * h * w;
    if total <= tile_size {
        return (None, Some("already_fits"));
    }
    let max_tile = ((tile_size as f64) / (channels as f64)).sqrt() as i64;
    if max_tile < min_tile {
        return (None, Some("min_tile_too_large"));
    }
    let target_tile = max_tile.min(h).min(w);
    match find_optimal_tile_size(h, target_tile, min_tile, stride) {
        Some(t) => (Some(t), None),
        None => (None, Some("no_divisor")),
    }
}

/// Split `c_in` input channels into groups whose `channels x h x w` input
/// fits in `tile_size` elements.
///
/// Returns `(num_groups, channels_per_group)` only when more than one group
/// is needed; the last group may hold fewer channels. Returns `None` when:
/// - `h` or `w` is zero;
/// - a single channel plane already exceeds the budget;
/// - all channels already fit.
///
/// `_c_out` is accepted but unused.
fn calculate_channel_split_config(
    c_in: i64,
    _c_out: i64,
    h: i64,
    w: i64,
    tile_size: i64,
) -> Option<(i64, i64)> {
    if h == 0 || w == 0 {
        return None;
    }
    // Maximum channels per group that keep the group within budget.
    let max_ch = tile_size / (h * w);
    if max_ch >= 1 && max_ch < c_in {
        let mut num_groups = (c_in + max_ch - 1) / max_ch;
        if num_groups > 1 {
            let mut cpg = (c_in + num_groups - 1) / num_groups;
            // Drop trailing groups that would otherwise receive no channels.
            while cpg * (num_groups - 1) >= c_in && num_groups > 1 {
                num_groups -= 1;
                cpg = (c_in + num_groups - 1) / num_groups;
            }
            if num_groups > 1 {
                return Some((num_groups, cpg));
            }
        }
    }
    None
}

/// Element budget (`channels * tile_h * tile_w`, or `channels * h * w` per
/// channel group) used by `detect_tiling_needs` for Conv slices.
pub const CONV_TILE_BUDGET: i64 = 512;
/// Element budget used by `detect_tiling_needs` for MaxPool slices.
pub const POOL_TILE_BUDGET: i64 = 1024;

/// Decide whether a sliced model needs tiling. Strategies are tried in order:
/// 1. spatial tiling of a single leading Conv (`CONV_TILE_BUDGET`);
/// 2. input-channel splitting of an ungrouped Conv (`CONV_TILE_BUDGET`);
/// 3. spatial tiling of a single leading MaxPool (`POOL_TILE_BUDGET`);
/// 4. fixed-size segmentation of an elementwise-only slice.
///
/// Spatial tiling needs three things:
/// - `h` and `w` both divisible by the tile edge;
/// - the tile edge divisible by both strides;
/// - at least two tiles in total.
///
/// Returns `None` when `tile_size` is `None`, when the model has no graph,
/// or when no strategy applies.
///
/// NOTE(review): only the *presence* of `tile_size` is checked; its value
/// is never read, and the budgets come from the constants. The `?` on
/// `compute_min_spatial_tile` / `compute_halo_size` returns from the whole
/// function, which also skips the later fallbacks. `c_in` comes from the
/// graph input's channel dim and is not cross-checked against the Conv
/// weight. Confirm these are intended.
pub fn detect_tiling_needs(
    model: &ModelProto,
    tile_size: Option,
) -> Option {
    let graph = model.graph.as_ref()?;
    // No tile size supplied -> no tiling.
    tile_size?;
    let dims_4d = get_model_dimensions(graph);
    if let Some((ref inp_name, ref out_name, c_in, h, w)) = dims_4d
        && let Some(cp) = get_conv_params(graph)
    {
        let budget = CONV_TILE_BUDGET;
        let c_out = cp.c_out;
        if is_tileable(graph) {
            let min_tile = compute_min_spatial_tile(cp.kernel, cp.dilation)?;
            let (actual_tile, _skip_reason) =
                calculate_spatial_tile_config(c_in, h, w, budget, min_tile, cp.stride[0]);
            // calculate_spatial_tile_config only checks H and the H stride.
            if let Some(actual_tile) = actual_tile
                && h % actual_tile == 0
                && w % actual_tile == 0
                && actual_tile % cp.stride[0] == 0
                && actual_tile % cp.stride[1] == 0
            {
                let tiles_y = h / actual_tile;
                let tiles_x = w / actual_tile;
                if tiles_y * tiles_x >= 2 {
                    let halo = compute_halo_size(cp.pads)?;
                    return Some(TilingDetection::Spatial {
                        input_name: inp_name.clone(),
                        output_name: out_name.clone(),
                        input_names: vec![inp_name.clone()],
                        ndim: 4,
                        c_in,
                        c_out,
                        h,
                        w,
                        tile_size: actual_tile,
                        halo,
                        tiles_y,
                        tiles_x,
                        out_tile: [actual_tile / cp.stride[0], actual_tile / cp.stride[1]],
                        stride: cp.stride,
                    });
                }
            }
        }
        if is_channel_splittable(graph)
            && let Some((num_groups, cpg)) = calculate_channel_split_config(c_in, c_out, h, w, budget)
        {
            return Some(TilingDetection::ChannelSplit {
                input_name: inp_name.clone(),
                output_name: out_name.clone(),
                c_in,
                c_out,
                h,
                w,
                num_groups,
                channels_per_group: cpg,
            });
        }
    }
    if let Some((ref inp_name, ref out_name, c_in, h, w)) = dims_4d
        && is_spatial_tileable(graph, "MaxPool")
        && let Some(pp) = get_pool_params(graph)
    {
        let budget = POOL_TILE_BUDGET;
        let min_tile = compute_min_spatial_tile(pp.kernel, pp.dilation)?;
        let (actual_tile, _skip_reason) =
            calculate_spatial_tile_config(c_in, h, w, budget, min_tile, pp.stride[0]);
        if let Some(actual_tile) = actual_tile
            && h % actual_tile == 0
            && w % actual_tile == 0
            && actual_tile % pp.stride[0] == 0
            && actual_tile % pp.stride[1] == 0
        {
            let tiles_y = h / actual_tile;
            let tiles_x = w / actual_tile;
            if tiles_y * tiles_x >= 2 {
                let halo = compute_halo_size(pp.pads)?;
                return Some(TilingDetection::Spatial {
                    input_name: inp_name.clone(),
                    output_name: out_name.clone(),
                    input_names: vec![inp_name.clone()],
                    ndim: 4,
                    c_in,
                    // MaxPool is applied per channel, so channel count is preserved.
                    c_out: c_in,
                    h,
                    w,
                    tile_size: actual_tile,
                    halo,
                    tiles_y,
                    tiles_x,
                    out_tile: [actual_tile / pp.stride[0], actual_tile / pp.stride[1]],
                    stride: pp.stride,
                });
            }
        }
    }
    if let Some(detection) = detect_elementwise_fixed_segments(graph) {
        return Some(detection);
    }
    None
}

/// Default number of elements per segment when an elementwise-only slice is
/// split into fixed-size segments. It can be overridden with the
/// `DSPERSE_EW_SEGMENT_SIZE` environment variable.
pub const ELEMENTWISE_SEGMENT_SIZE: i64 = 1024;

/// Segment size taken from `DSPERSE_EW_SEGMENT_SIZE` when that variable
/// parses as a positive integer, otherwise `ELEMENTWISE_SEGMENT_SIZE`.
/// The environment is read on every call.
fn elementwise_segment_size() -> i64 {
    std::env::var("DSPERSE_EW_SEGMENT_SIZE")
        .ok()
        .and_then(|v| v.parse::().ok())
        .filter(|&v| v > 0)
        .unwrap_or(ELEMENTWISE_SEGMENT_SIZE)
}

/// Split an elementwise-only slice into fixed-size flat segments.
///
/// Requirements:
/// - every graph input has exactly the first input's fully static
///   (all-positive) shape;
/// - the total element count exceeds the configured segment size;
/// - the final split has at least two segments.
///
/// Initializers restrict the split (see the loop below).
fn detect_elementwise_fixed_segments(graph: &GraphProto) -> Option {
    if !is_elementwise_only_slice(graph) {
        return None;
    }
    let seg_size = elementwise_segment_size();
    let out = graph.output.first()?;
    let first_inp = graph.input.first()?;
    let first_dims = onnx_proto::vi_shape(first_inp);
    if first_dims.is_empty() || first_dims.iter().any(|&d| d <= 0) {
        return None;
    }
    let total_elements = first_dims
        .iter()
        .try_fold(1i64, |acc, &d| acc.checked_mul(d))?;
    if total_elements <= seg_size {
        return None;
    }
    let last_dim = *first_dims.last().unwrap_or(&0);
    let mut effective_seg_size = seg_size;
    // Initializer handling:
    // - volume <= 1 (scalars/empty): allowed.
    // - volume == segment size: allowed.
    // - 1-D with length `last_dim`: allowed, and the segment size becomes
    //   `last_dim`, i.e. one innermost row per segment.
    // - anything else: the fixed-segment split is rejected.
    //
    // NOTE(review): `product()` is unchecked. The `vol == seg_size` case
    // ignores the initializer's shape, and is compared against the segment
    // size *before* a later 1-D `last_dim` initializer may change it, which
    // could leave that initializer misaligned with the final segments.
    // Confirm intended.
    for init in &graph.initializer {
        let vol: i64 = init.dims.iter().product();
        if vol <= 1 || vol == seg_size {
            continue;
        }
        if init.dims.len() == 1 && init.dims[0] == last_dim && last_dim > 0 {
            effective_seg_size = last_dim;
            continue;
        }
        return None;
    }
    let seg_size = effective_seg_size;
    let mut input_names = Vec::with_capacity(graph.input.len());
    for inp in &graph.input {
        let d = onnx_proto::vi_shape(inp);
        if d != first_dims || d.iter().any(|&v| v <= 0) {
            return None;
        }
        input_names.push(inp.name.clone());
    }
    // Ceiling division: the final segment may be shorter than `seg_size`.
    #[allow(clippy::manual_div_ceil)]
    let num_segments = (total_elements + seg_size - 1) / seg_size;
    if num_segments < 2 {
        return None;
    }
    let primary_name = input_names[0].clone();
    Some(TilingDetection::FixedSegment {
        input_name: primary_name,
        output_name: out.name.clone(),
        input_names,
        total_elements,
        segment_size: seg_size,
        num_segments,
        original_shape: first_dims,
    })
}

/// Estimated-constraint budget. `detect_dim_split` only looks for a split
/// when a slice's estimate exceeds this value, and sizes groups/chunks
/// against it.
pub const MAX_ESTIMATED_CONSTRAINTS: u64 = 750_000;

/// Return the smallest divisor of `dim` that is `>= min(target, dim)`.
///
/// Because `target` is clamped to `dim`, and `dim` always divides itself,
/// this returns `Some` for every non-zero `dim` and `target`. When
/// `target > dim` the result is `dim` itself (one element per group), which
/// may still leave each group above the caller's budget. `None` is returned
/// only when `dim` or `target` is zero.
///
/// An exact divisor is chosen so that the last group never needs
/// pad-then-trim. Padding would inject zeros into reductions on non-split
/// axes (Softmax, LayerNorm, ReduceMean, etc.) and contaminate the unpadded
/// region's outputs.
fn smallest_divisor_at_least(dim: usize, target: usize) -> Option {
    if dim == 0 || target == 0 {
        return None;
    }
    let target = target.min(dim);
    (target..=dim).find(|&g| dim.is_multiple_of(g))
}

/// Result of `detect_dim_split`: how to split one over-budget slice into
/// groups whose outputs are concatenated back together.
#[derive(Debug, Clone)]
pub struct DimSplitDetection {
    // Kind of split performed.
    pub split_kind: DimSplitKind,
    // Axis of the input tensor being split (0 for the MatMul/Gemm path).
    pub split_dim: usize,
    // Size of the split axis. For MatMul/Gemm this is the row count, the
    // product of all but the last input dim.
    pub dim_size: usize,
    // Number of groups the split axis is divided into.
    pub num_groups: usize,
    // dim_size / num_groups.
    pub elements_per_group: usize,
    // Tensor that is split.
    pub input_name: String,
    // Tensor produced by concatenating the per-group outputs.
    pub output_name: String,
    // Axis along which per-group outputs are concatenated.
    pub concat_axis: usize,
    // Whole-slice estimate from `estimate_slice_constraints`.
    pub estimated_constraints: u64,
    // MatMul/Gemm path only: weight initializer name. Other paths set None.
    pub weight_name: Option,
    // MatMul/Gemm path only: contraction (K) size, honoring Gemm `transB`;
    // 0 otherwise.
    pub k_dim: usize,
    // MatMul/Gemm path only: output (N) size, honoring Gemm `transB`;
    // 0 otherwise.
    pub n_dim: usize,
    // MatMul/Gemm path only: number of chunks the K axis is split into;
    // 1 otherwise.
    pub k_chunks: usize,
}

/// Sum the per-node constraint estimates from
/// `jstprove_circuits::api::estimate_op_constraints` (BN254 default
/// configuration) over `nodes`, saturating at `u64::MAX`.
///
/// Tensors missing from `shapes` are passed as an empty shape. Non-positive
/// dims are clamped to 1; these are presumably unknown or dynamic dims —
/// confirm the encoding used by `shapes`.
pub fn estimate_slice_constraints(nodes: &[NodeProto], shapes: &HashMap>) -> u64 {
    let config = jstprove_circuits::api::EstimationConfig::bn254_defaults();
    let mut total: u64 = 0;
    let to_usize_shape = |name: &String| -> Vec {
        shapes
            .get(name)
            .map(|s| s.iter().map(|&d| d.max(1) as usize).collect())
            .unwrap_or_default()
    };
    for node in nodes {
        let input_shapes: Vec> = node.input.iter().map(&to_usize_shape).collect();
        let output_shapes: Vec> = node.output.iter().map(&to_usize_shape).collect();
        let cost = jstprove_circuits::api::estimate_op_constraints(
            &node.op_type,
            &input_shapes,
            &output_shapes,
            &config,
        );
        total = total.saturating_add(cost);
    }
    total
}

pub fn detect_dim_split(
    nodes: &[NodeProto],
    shapes: &HashMap>,
    initializer_names: &HashSet,
    model_opset: i64,
) -> Option {
    let estimated = estimate_slice_constraints(nodes, shapes);
    if estimated <= MAX_ESTIMATED_CONSTRAINTS {
        return None;
    }
    let target_groups = estimated.div_ceil(MAX_ESTIMATED_CONSTRAINTS) as usize;
    for (idx, node) in nodes.iter().enumerate() {
        if matches!(node.op_type.as_str(), "MatMul" | "Gemm") {
            // Gemm with a bias (input C) is not yet supported by the
dim-split // template builder; skip so the template construction downstream // stays in sync with the detector. if node.op_type == "Gemm" && node.input.get(2).is_some_and(|s: &String| !s.is_empty()) { continue; } // The dim-split runner replaces the entire slice execution with // the patched MatMul template and only writes ds.output_name to // the tensor cache. If this MatMul/Gemm output is consumed by a // later node in the same slice, those downstream ops would never // execute and the slice would publish the wrong tensor. Decline // and let the search continue or fall through to other paths. let Some(node_out) = node.output.first().filter(|s| !s.is_empty()) else { continue; }; let consumed_downstream = nodes .iter() .skip(idx + 1) .any(|later| later.input.iter().any(|i| i == node_out)); if consumed_downstream { continue; } let Some(weight_name) = node.input.get(1) else { continue; }; if !initializer_names.contains(weight_name) { continue; } let Some(weight_shape) = shapes.get(weight_name) else { continue; }; if weight_shape.len() != 2 { continue; } // Gemm with transA=1 transposes the activation matrix, which the // single-row sequence tile and the rank-2 template do not model. // Skip so detection stays consistent with the template builder. 
if node.op_type == "Gemm" && super::onnx_proto::get_attribute_int(node, "transA").unwrap_or(0) == 1 { continue; } let trans_b = node.op_type == "Gemm" && super::onnx_proto::get_attribute_int(node, "transB").unwrap_or(0) == 1; let k_dim = if trans_b { weight_shape[1] as usize } else { weight_shape[0] as usize }; let n_dim = if trans_b { weight_shape[0] as usize } else { weight_shape[1] as usize }; let Some(inp_shape) = node.input.first().and_then(|name| shapes.get(name)) else { continue; }; let total_rows: usize = inp_shape .iter() .take(inp_shape.len().saturating_sub(1)) .map(|&d| d.max(1) as usize) .product(); if total_rows == 0 || k_dim == 0 || n_dim == 0 { continue; } let row_cost = k_dim.saturating_mul(n_dim).saturating_mul(2); let max_per_chunk = MAX_ESTIMATED_CONSTRAINTS as usize; // Even with k_chunks == k_dim (chunk_size == 1), the per-chunk // cost is at minimum n_dim * 2. If that alone exceeds the budget // the split is infeasible; let the caller fall through to other // detection paths. 
if n_dim.saturating_mul(2) > max_per_chunk { continue; } let mut k_chunks = if row_cost > max_per_chunk { row_cost.div_ceil(max_per_chunk).max(1) } else { 1 }; k_chunks = k_chunks.min(k_dim); while k_chunks < k_dim && k_dim .div_ceil(k_chunks) .saturating_mul(n_dim) .saturating_mul(2) > max_per_chunk { k_chunks += 1; } if total_rows == 1 && k_chunks == 1 { continue; } let Some(input_name) = node.input.first().filter(|s| !s.is_empty()).cloned() else { continue; }; let Some(output_name) = node.output.first().filter(|s| !s.is_empty()).cloned() else { continue; }; return Some(DimSplitDetection { split_kind: DimSplitKind::MatMulOutputDim, split_dim: 0, dim_size: total_rows, num_groups: total_rows, elements_per_group: 1, input_name, output_name, concat_axis: 0, estimated_constraints: estimated, weight_name: Some(weight_name.clone()), k_dim, n_dim, k_chunks, }); } } // Slice-boundary inputs are ones not produced by any node inside // this slice; everything else is internal data flow that the // dim-split rewrite cannot honour as a true split axis. let slice_internal_outputs: HashSet<&str> = nodes .iter() .flat_map(|n| n.output.iter()) .filter(|s| !s.is_empty()) .map(String::as_str) .collect(); for node in nodes { if node.op_type == "Softmax" { let Some(softmax_in) = node.input.first().and_then(|name| shapes.get(name)) else { continue; }; if softmax_in.len() != 4 { continue; } // ONNX Softmax default axis: opset >= 13 -> -1 // (last axis), opset < 13 -> 1 (channel axis). // unwrap_or(-1) silently mismatches runtime semantics on // opset <13 models that omit the attribute. 
let default_axis: i64 = if model_opset >= 13 { -1 } else { 1 }; let softmax_axis = onnx_proto::get_attribute_int(node, "axis").unwrap_or(default_axis); let softmax_axis_abs = if softmax_axis < 0 { (softmax_in.len() as i64 + softmax_axis).max(0) as usize } else { softmax_axis as usize }; // Find the attention-block input among the slice inputs: // the first slice-boundary tensor (external -- not the // output of any other node in this slice, and not an // initializer) whose rank matches the softmax input rank // (Q/V-like activation). let attn_input = nodes.iter().flat_map(|n| n.input.iter()).find(|name| { !name.is_empty() && !initializer_names.contains(name.as_str()) && !slice_internal_outputs.contains(name.as_str()) && shapes.get(*name).is_some_and(|s| s.len() == 4 && s[0] > 0) }); let Some(attn_input_name) = attn_input.cloned() else { continue; }; let Some(attn_shape) = shapes.get(&attn_input_name) else { continue; }; // Choose the dim (among 0..rank) that is not the softmax-reduction // axis and yields the highest axis size; that axis gives the // most groups and the lowest per-group cost. 
let mut best: Option<(usize, usize, DimSplitKind)> = None; for (d, &axis_len) in attn_shape.iter().enumerate() { if d == softmax_axis_abs { continue; } let dim_size = axis_len.max(1) as usize; if dim_size < 2 { continue; } let kind = if d == 1 { DimSplitKind::HeadDim } else { DimSplitKind::BatchDim }; let better = best.as_ref().is_none_or(|(_, sz, _)| dim_size > *sz); if better { best = Some((d, dim_size, kind)); } } let Some((split_dim, dim_size, split_kind)) = best else { continue; }; let num_groups = match smallest_divisor_at_least(dim_size, target_groups) { Some(g) => g, None => continue, }; let elements_per_group = dim_size / num_groups; let output_name = nodes .last() .and_then(|n| n.output.first()) .filter(|s| !s.is_empty()) .cloned() .unwrap_or_else(|| node.output.first().cloned().unwrap_or_default()); if output_name.is_empty() { continue; } // Reject the split when axis tracing through the slice cannot // prove the split axis lands at the same position (and size) // in the final output. Shape-reordering ops (Reshape, // Transpose, Flatten, Squeeze, Unsqueeze, Concat on the // split axis) are non-trivial to follow here, so we require // the output shape to match the attention input at split_dim. 
let Some(out_shape) = shapes.get(&output_name) else { continue; }; if out_shape.len() != attn_shape.len() || out_shape[split_dim] != attn_shape[split_dim] { continue; } return Some(DimSplitDetection { split_kind, split_dim, dim_size, num_groups, elements_per_group, input_name: attn_input_name, output_name, concat_axis: split_dim, estimated_constraints: estimated, weight_name: None, k_dim: 0, n_dim: 0, k_chunks: 1, }); } } let first_non_init_input = nodes.first().and_then(|n| { n.input .iter() .find(|name| !name.is_empty() && !initializer_names.contains(name.as_str())) }); let first_input_shape = first_non_init_input.and_then(|name| shapes.get(name))?; if first_input_shape.is_empty() { return None; } // Conv / ConvTranspose / Pooling are not separable along arbitrary // input axes: splitting the input channel or the spatial dimensions // produces semantically incorrect per-group outputs. The dedicated // detection paths (conv spatial tiling, channel splitting) handle // these ops correctly; this generic fallback refuses to emit a // split for them. MatMul / Gemm are *not* listed here: their // dedicated dim-split-k path handles the K-axis split when the // weight is an initializer, but non-terminal MatMul/Gemm slices or // slices whose weight is a runtime tensor still benefit from the // generic axis-0 (batch) fallback, which is always semantically // sound because the batch dimension is independent across rows. for node in nodes { if matches!( node.op_type.as_str(), "Conv" | "ConvTranspose" | "AveragePool" | "MaxPool" | "GlobalAveragePool" | "GlobalMaxPool" | "LRN" ) { return None; } } // Find the deepest split_dim that is still compatible with every // normalization-style op in the slice. Splitting a later axis produces // more groups and a smaller per-group cost without violating op semantics. 
let rank = first_input_shape.len(); // If the slice contains any axis-reordering op (Transpose) AND any // axis-sensitive normalization op (LayerNormalization / Softmax), // we can no longer cheaply trace which axis the normalization // really runs on after the reorder. Restrict the split to axis 0 // (always the batch dim, always semantically sound) so we never // emit a split that lands on the post-Transpose normalization axis. let has_transpose = nodes.iter().any(|n| n.op_type == "Transpose"); let has_norm = nodes.iter().any(|n| { matches!( n.op_type.as_str(), "LayerNormalization" | "Softmax" | "LogSoftmax" ) }); let mut max_allowed = if has_transpose && has_norm { 1 } else { rank }; for node in nodes { match node.op_type.as_str() { "LayerNormalization" => { let axis = onnx_proto::get_attribute_int(node, "axis").unwrap_or(-1); let resolved = if axis < 0 { (rank as i64 + axis).max(0) as usize } else { (axis as usize).min(rank) }; if resolved < max_allowed { max_allowed = resolved; } } "Softmax" | "LogSoftmax" => { // ONNX Softmax / LogSoftmax default axis: opset >= // 13 -> -1 (last axis), opset < 13 -> 1 (channel // axis). unwrap_or(-1) silently mismatches runtime // semantics on opset < 13 models that omit the // attribute. let default_axis: i64 = if model_opset >= 13 { -1 } else { 1 }; let axis = onnx_proto::get_attribute_int(node, "axis").unwrap_or(default_axis); let resolved = if axis < 0 { (rank as i64 + axis).max(0) as usize } else { (axis as usize).min(rank.saturating_sub(1)) }; if resolved < max_allowed { max_allowed = resolved; } } "BatchNormalization" => { // BatchNorm couples every spatial element to the // running mean / variance per channel; splitting any // axis would change those statistics. Force-reject // the dim-split unconditionally so the early return at // line 1044 fires regardless of prior state. 
max_allowed = 0; } _ => {} } } if max_allowed == 0 { return None; } let mut best: Option<(usize, usize)> = None; for (d, &axis_len) in first_input_shape.iter().enumerate().take(max_allowed) { let dim = axis_len.max(1) as usize; if dim <= 1 { continue; } if best.map(|(_, size)| dim > size).unwrap_or(true) { best = Some((d, dim)); } } let (split_dim, dim_size) = best?; let num_groups = smallest_divisor_at_least(dim_size, target_groups)?; let elements_per_group = dim_size / num_groups; let input_name = first_non_init_input.cloned()?; let output_name = nodes .last() .and_then(|n| n.output.first()) .filter(|s| !s.is_empty()) .cloned()?; // Require the final output shape to preserve rank and the split // axis size; otherwise an intermediate op (Reshape, Transpose, // Flatten, Squeeze, Unsqueeze) has reordered the axes and // concat_axis=split_dim would splice the groups into the wrong // output dimension. Tracing the axis through an arbitrary chain // of shape ops is out of scope here, so we conservatively reject. 
// ---- tail of the generic dim-split detection function (head precedes this chunk) ----
    let out_shape = shapes.get(&output_name)?;
    if out_shape.len() != first_input_shape.len()
        || out_shape[split_dim] != first_input_shape[split_dim]
    {
        return None;
    }
    Some(DimSplitDetection {
        split_kind: DimSplitKind::BatchDim,
        split_dim,
        dim_size,
        num_groups,
        elements_per_group,
        input_name,
        output_name,
        concat_axis: split_dim,
        estimated_constraints: estimated,
        // The weight / K-chunking fields carry neutral values for this split.
        weight_name: None,
        k_dim: 0,
        n_dim: 0,
        k_chunks: 1,
    })
}

/// Tiling decision for a slice, together with the geometry needed to
/// materialize it.
///
/// NOTE(review): the per-variant descriptions below are read off the field
/// names; confirm against the detection code that constructs them.
#[derive(Debug, Clone)]
pub enum TilingDetection {
    /// Spatial (H x W) tiling into `tiles_y` x `tiles_x` tiles of
    /// `tile_size`, each widened by `halo`.
    Spatial {
        input_name: String,
        output_name: String,
        input_names: Vec,
        ndim: i64,
        c_in: i64,
        c_out: i64,
        h: i64,
        w: i64,
        tile_size: i64,
        // Presumably ordered like ONNX `pads` (H-begin, W-begin, H-end,
        // W-end), matching how `compute_spatial_tile_geometry` sums
        // halo[0]+halo[2] for height and halo[1]+halo[3] for width.
        halo: [i64; 4],
        tiles_y: i64,
        tiles_x: i64,
        out_tile: [i64; 2],
        stride: [i64; 2],
    },
    /// Input-channel split into `num_groups` groups of `channels_per_group`
    /// channels.
    ChannelSplit {
        input_name: String,
        output_name: String,
        c_in: i64,
        c_out: i64,
        h: i64,
        w: i64,
        num_groups: i64,
        channels_per_group: i64,
    },
    /// Split of `total_elements` into `num_segments` segments of
    /// `segment_size`; `original_shape` presumably restores the layout.
    FixedSegment {
        input_name: String,
        output_name: String,
        input_names: Vec,
        total_elements: i64,
        segment_size: i64,
        num_segments: i64,
        original_shape: Vec,
    },
}

/// Per-tile geometry produced by `compute_spatial_tile_geometry`.
struct SpatialTileGeometry {
    // Input channels, read from the graph's first input.
    c_in: i64,
    // Output channels: the caller's override, or `c_in`.
    c_out: i64,
    // Tile input size: `tile_size` plus the halo on both sides.
    tile_h: i64,
    tile_w: i64,
    // Spatial output size of one tile, computed with zero padding.
    out_h: i64,
    out_w: i64,
}

/// Compute the input and output geometry of a single spatial tile.
///
/// The tile input extent is `tile_size` plus the halo returned by
/// `compute_halo_size(pads)` on both sides of each axis (`halo[0] + halo[2]`
/// for height, `halo[1] + halo[3]` for width). The tile output size is then
/// evaluated with zero pads; the tile ops emitted by the callers also use
/// zero pads, so the halo stands in for the original padding. `c_in` comes
/// from the graph's first input, which must be rank 4 with a positive dim 1;
/// `c_out` is `c_out_override`, or `c_in` when `None`.
///
/// # Errors
/// Invalid pads, `i64` overflow while adding the halo, invalid output dims
/// from `conv_output_hw`, or an undeterminable input channel count.
fn compute_spatial_tile_geometry(
    graph: &GraphProto,
    pads: [i64; 4],
    kernel: [i64; 2],
    dilation: [i64; 2],
    stride: [i64; 2],
    tile_size: i64,
    c_out_override: Option,
) -> Result {
    let halo = compute_halo_size(pads).ok_or_else(|| {
        crate::error::DsperseError::Slicer("spatial tile: invalid pad values".to_string())
    })?;
    // Checked adds guard against i64 overflow of tile_size + halo.
    let tile_h = tile_size
        .checked_add(halo[0])
        .and_then(|v| v.checked_add(halo[2]))
        .ok_or_else(|| {
            crate::error::DsperseError::Slicer(format!(
                "spatial tile: tile_h overflow (tile_size={tile_size}, halo={:?})",
                halo
            ))
        })?;
    let tile_w = tile_size
        .checked_add(halo[1])
        .and_then(|v| v.checked_add(halo[3]))
        .ok_or_else(|| {
            crate::error::DsperseError::Slicer(format!(
                "spatial tile: tile_w overflow (tile_size={tile_size}, halo={:?})",
                halo
            ))
        })?;
    let (out_h, out_w) =
        conv_output_hw(tile_h, tile_w, [0, 0, 0, 0], kernel, dilation, stride)
            .ok_or_else(|| {
                crate::error::DsperseError::Slicer(format!(
                    "spatial tile: invalid output dims for tile_h={tile_h}, tile_w={tile_w}, stride={stride:?}, kernel={kernel:?}"
                ))
            })?;
    let c_in = graph
        .input
        .first()
        .map(onnx_proto::vi_shape)
        .and_then(|s| (s.len() == 4 && s[1] > 0).then_some(s[1]))
        .ok_or_else(|| {
            crate::error::DsperseError::Slicer(
                "spatial tile: unable to determine input channels".to_string(),
            )
        })?;
    let c_out = c_out_override.unwrap_or(c_in);
    Ok(SpatialTileGeometry {
        c_in,
        c_out,
        tile_h,
        tile_w,
        out_h,
        out_w,
    })
}

/// Contents of a single-graph tile model handed to `save_tile_model`.
struct TileModelSpec {
    nodes: Vec,
    input: onnx_proto::ValueInfoProto,
    output: onnx_proto::ValueInfoProto,
    initializers: Vec,
    // Tile output [H, W]; returned as `TileSliceResult::conv_out`.
    out_hw: [i64; 2],
}

/// Build a graph named `tile_{slice_idx}` from `spec` (using the source
/// model's opset), write it to `<output_dir>/tiles/tile.onnx`, and return
/// its payload-relative path together with the tile output H/W.
///
/// # Errors
/// Propagates directory-creation and model-save failures.
fn save_tile_model(
    model: &ModelProto,
    spec: TileModelSpec,
    slice_idx: usize,
    output_dir: &Path,
) -> Result {
    let graph = onnx_proto::make_graph(
        &format!("tile_{slice_idx}"),
        spec.nodes,
        vec![spec.input],
        vec![spec.output],
        spec.initializers,
    );
    let tile_model = onnx_proto::make_model(graph, model_opset(model));
    let tiles_dir = output_dir.join("tiles");
    std::fs::create_dir_all(&tiles_dir)
        .map_err(|e| crate::error::DsperseError::io(e, &tiles_dir))?;
    let onnx_path = tiles_dir.join("tile.onnx");
    onnx_proto::save_model(&tile_model, &onnx_path)?;
    // The reported path assumes `output_dir` is the slice's payload
    // directory (`slice_{slice_idx}/payload`) -- TODO confirm with callers.
    Ok(TileSliceResult {
        path: format!("slice_{slice_idx}/payload/tiles/tile.onnx"),
        conv_out: spec.out_hw,
    })
}

/// Emit the per-tile Conv model for a slice whose prologue is a Conv.
///
/// The tile graph is `tile_in [1, c_in, tile_h, tile_w]` -> Conv (original
/// kernel / stride / dilation / group, zero pads, original weights as `W`
/// and optional bias as `B`) -> the slice's remaining ops (wired by
/// `integrate_extra_ops`) -> `tile_out [1, c_out, out_h, out_w]`.
///
/// # Errors
/// `tile_size <= 0`, no extractable slice prologue or conv weights, a graph
/// input channel count that disagrees with the weight's `c_in * group`, or
/// any failure from the geometry computation, extra-op wiring or saving.
pub fn create_tile_slice(
    model: &ModelProto,
    tile_size: i64,
    slice_idx: usize,
    output_dir: &Path,
) -> Result {
    if tile_size <= 0 {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_tile_slice: tile_size must be > 0, got {tile_size}"
        )));
    }
    let SlicePrologue {
        graph,
        cp,
        weights,
        bias,
    } = extract_slice_prologue(model).ok_or_else(|| {
        crate::error::DsperseError::Slicer(
            "create_tile_slice: failed to extract slice prologue".to_string(),
        )
    })?;
    let conv_node = &graph.node[cp.node_idx];
    let weights = weights.ok_or_else(|| {
        crate::error::DsperseError::Slicer("create_tile_slice: conv weights not found".to_string())
    })?;
    // Expected graph input channels = weight c_in * group (per the error
    // message below). None on overflow or a non-positive product, which
    // skips the consistency check.
    let cfg_c_in = cp.c_in.checked_mul(cp.group).filter(|&v| v > 0);
    // Output channel count is taken from the weight's first dimension.
    let geom = compute_spatial_tile_geometry(
        graph,
        cp.pads,
        cp.kernel,
        cp.dilation,
        cp.stride,
        tile_size,
        Some(weights.dims[0]),
    )?;
    if let Some(c) = cfg_c_in
        && geom.c_in != c
    {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_tile_slice: graph c_in ({}) != weight c_in*group ({c})",
            geom.c_in
        )));
    }
    let x = onnx_proto::make_tensor_value_info(
        "tile_in",
        TensorProto::FLOAT,
        &[1, geom.c_in, geom.tile_h, geom.tile_w],
    );
    let y = onnx_proto::make_tensor_value_info(
        "tile_out",
        TensorProto::FLOAT,
        &[1, geom.c_out, geom.out_h, geom.out_w],
    );
    let mut initializers = vec![onnx_proto::make_tensor(
        "W",
        TensorProto::FLOAT,
        &weights.dims,
        weights.data,
    )];
    let mut conv_inputs = vec!["tile_in".to_string(), "W".to_string()];
    if let Some(bias_data) = &bias {
        let bias_dims = [geom.c_out];
        initializers.push(onnx_proto::make_tensor(
            "B",
            TensorProto::FLOAT,
            &bias_dims,
            bias_data.clone(),
        ));
        conv_inputs.push("B".to_string());
    }
    // Zero pads: the tile input already includes the halo.
    let mut conv_attrs = vec![
        onnx_proto::make_attribute_ints("kernel_shape", &cp.kernel),
        onnx_proto::make_attribute_ints("strides", &cp.stride),
        onnx_proto::make_attribute_ints("pads", &[0, 0, 0, 0]),
        onnx_proto::make_attribute_ints("dilations", &cp.dilation),
    ];
    // `group` is only emitted when it differs from the ONNX default of 1.
    if cp.group != 1 {
        conv_attrs.push(onnx_proto::make_attribute_int("group", cp.group));
    }
    let mut nodes = vec![onnx_proto::make_node(
        "Conv",
        conv_inputs,
        vec!["conv_out".to_string()],
        conv_attrs,
    )];
    integrate_extra_ops(graph, conv_node, &mut initializers, &mut nodes)?;
    save_tile_model(
        model,
        TileModelSpec {
            nodes,
            input: x,
            output: y,
            initializers,
            out_hw: [geom.out_h, geom.out_w],
        },
        slice_idx,
        output_dir,
    )
}

/// Emit the per-tile pooling model for a slice whose pooling node is
/// located by `get_pool_params`.
///
/// NOTE(review): the tile op is always `MaxPool`; if `get_pool_params` can
/// return other pooling ops, the original op type should be used instead.
pub fn create_pool_tile_slice(
    model: &ModelProto,
    tile_size: i64,
    slice_idx: usize,
    output_dir: &Path,
) -> Result {
    if tile_size <= 0 {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_pool_tile_slice: tile_size must be > 0, got {tile_size}"
        )));
    }
    let graph = model.graph.as_ref().ok_or_else(|| {
// ---- continuation of `create_pool_tile_slice` ----
        crate::error::DsperseError::Slicer(
            "create_pool_tile_slice: model.graph is None".to_string(),
        )
    })?;
    let pp = get_pool_params(graph).ok_or_else(|| {
        crate::error::DsperseError::Slicer(
            "create_pool_tile_slice: no MaxPool node found".to_string(),
        )
    })?;
    let pool_node = &graph.node[pp.node_idx];
    // No c_out override: the tile output keeps the input channel count.
    let geom = compute_spatial_tile_geometry(
        graph, pp.pads, pp.kernel, pp.dilation, pp.stride, tile_size, None,
    )?;
    let x = onnx_proto::make_tensor_value_info(
        "tile_in",
        TensorProto::FLOAT,
        &[1, geom.c_in, geom.tile_h, geom.tile_w],
    );
    let y = onnx_proto::make_tensor_value_info(
        "tile_out",
        TensorProto::FLOAT,
        &[1, geom.c_out, geom.out_h, geom.out_w],
    );
    // Zero pads: the tile input already includes the halo, which presumably
    // stands in for the original padding.
    let pool_attrs = vec![
        onnx_proto::make_attribute_ints("kernel_shape", &pp.kernel),
        onnx_proto::make_attribute_ints("strides", &pp.stride),
        onnx_proto::make_attribute_ints("pads", &[0, 0, 0, 0]),
        onnx_proto::make_attribute_ints("dilations", &pp.dilation),
    ];
    let mut nodes = vec![onnx_proto::make_node(
        "MaxPool",
        vec!["tile_in".to_string()],
        vec!["pool_out".to_string()],
        pool_attrs,
    )];
    let mut initializers = Vec::new();
    integrate_extra_ops(graph, pool_node, &mut initializers, &mut nodes)?;
    save_tile_model(
        model,
        TileModelSpec {
            nodes,
            input: x,
            output: y,
            initializers,
            out_hw: [geom.out_h, geom.out_w],
        },
        slice_idx,
        output_dir,
    )
}

/// Append the slice's non-primary ops to a tile graph whose primary op is
/// already the last entry of `nodes`.
///
/// Every node of `graph` whose `op_type` differs from `primary_node`'s is
/// copied in graph order, keeping its op type, attributes and domain, with
/// its inputs rewired:
/// * an output of any primary-type node -> the primary op's output wire
///   (first output of the last node in `nodes`);
/// * the original graph input -> `tile_in`;
/// * anything else keeps its name.
///
/// The last copied node writes `tile_out`. All original initializers are
/// appended except those consumed as non-activation inputs (index >= 1) of
/// `primary_node`; the caller emits its own copies of those (e.g.
/// `create_tile_slice` adds `W` / `B`).
///
/// If there are no extra ops, the primary op's own first output is renamed
/// to `tile_out` instead.
///
/// # Errors
/// Only in the no-extra-ops case: `nodes` is empty or its last node has no
/// outputs.
fn integrate_extra_ops(
    graph: &GraphProto,
    primary_node: &NodeProto,
    initializers: &mut Vec,
    nodes: &mut Vec,
) -> crate::error::Result<()> {
    let primary_op = primary_node.op_type.as_str();
    let orig_input_name = graph.input.first().map(|i| i.name.as_str()).unwrap_or("");
    let extra: Vec<&NodeProto> = graph
        .node
        .iter()
        .filter(|n| n.op_type != primary_op)
        .collect();
    if extra.is_empty() {
        let last = nodes.last_mut().ok_or_else(|| {
            crate::error::DsperseError::Slicer(
                "integrate_extra_ops: no nodes to set output on".into(),
            )
        })?;
        let out = last.output.get_mut(0).ok_or_else(|| {
            crate::error::DsperseError::Slicer(
                "integrate_extra_ops: last node has no outputs".into(),
            )
        })?;
        *out = "tile_out".to_string();
        return Ok(());
    }
    // The primary op's weight / bias inputs are re-emitted by the caller,
    // so the original initializers backing them are not copied.
    let mut primary_weight_names: HashSet = HashSet::new();
    for inp in primary_node.input.iter().skip(1) {
        primary_weight_names.insert(inp.clone());
    }
    for init in &graph.initializer {
        if !primary_weight_names.contains(&init.name) {
            initializers.push(init.clone());
        }
    }
    let primary_outputs: HashSet = graph
        .node
        .iter()
        .filter(|n| n.op_type == primary_op)
        .flat_map(|n| n.output.iter().cloned())
        .collect();
    let primary_out_wire = nodes
        .last()
        .and_then(|n| n.output.first())
        .cloned()
        .unwrap_or_else(|| format!("{}_out", primary_op.to_lowercase()));
    for (i, orig_node) in extra.iter().enumerate() {
        let new_inputs: Vec = orig_node
            .input
            .iter()
            .map(|inp| {
                if primary_outputs.contains(inp) {
                    primary_out_wire.clone()
                } else if inp == orig_input_name {
                    "tile_in".to_string()
                } else {
                    inp.clone()
                }
            })
            .collect();
        let is_last = i == extra.len() - 1;
        let new_outputs = if is_last {
            vec!["tile_out".to_string()]
        } else {
            orig_node.output.clone()
        };
        nodes.push(NodeProto {
            op_type: orig_node.op_type.clone(),
            input: new_inputs,
            output: new_outputs,
            attribute: orig_node.attribute.clone(),
            name: String::new(),
            // Keep the operator domain: clearing it would silently rebind a
            // custom-domain op (e.g. `com.microsoft`) to the same-named op in
            // the default ONNX domain instead of surfacing the mismatch.
            domain: orig_node.domain.clone(),
            doc_string: String::new(),
            overload: String::new(),
            metadata_props: vec![],
            device_configurations: vec![],
        });
    }
    Ok(())
}

/// Emit one channel-group Conv model covering input channels
/// `[c_start, c_end)`.
///
/// The group graph is `group_{group_idx}_in [1, c_group, h_in, w_in]` ->
/// Conv (original pads / kernel / stride / dilation, weights sliced to the
/// group's input channels, no bias input) ->
/// `group_{group_idx}_out [1, c_out, h_out, w_out]`, written to
/// `<output_dir>/channel_groups/group_{group_idx}.onnx`. The bias is stored
/// separately by `save_conv_bias`.
///
/// # Errors
/// Invalid channel range, missing weights, invalid output dims, weight
/// slicing failures, or I/O failures.
#[allow(clippy::too_many_arguments)]
fn create_channel_group_slice(
    model: &ModelProto,
    prologue: &SlicePrologue<'_>,
    group_idx: usize,
    c_start: i64,
    c_end: i64,
    h_in: i64,
    w_in: i64,
    slice_idx: usize,
    output_dir: &Path,
) -> Result {
    let cp = &prologue.cp;
    if c_start < 0 || c_end < 0 || c_start >= c_end {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_channel_group_slice: invalid channel range c_start={c_start}, c_end={c_end}"
        )));
    }
    let weights = prologue.weights.as_ref().ok_or_else(|| {
        crate::error::DsperseError::Slicer(
            "create_channel_group_slice: conv weights not found".to_string(),
        )
    })?;
    let c_group = c_end - c_start;
    let (h_out, w_out) = conv_output_hw(h_in, w_in,
// ---- continuation of `create_channel_group_slice` ----
        cp.pads, cp.kernel, cp.dilation, cp.stride)
        .ok_or_else(|| {
            crate::error::DsperseError::Slicer(format!(
                "create_channel_group_slice: invalid output dims for h_in={h_in}, w_in={w_in}"
            ))
        })?;
    let c_out = cp.c_out;
    let input_name = format!("group_{group_idx}_in");
    let output_name = format!("group_{group_idx}_out");
    let x = onnx_proto::make_tensor_value_info(
        &input_name,
        TensorProto::FLOAT,
        &[1, c_group, h_in, w_in],
    );
    // Each group outputs all c_out channels: its contribution from its own
    // input-channel range (presumably summed across groups downstream).
    let y = onnx_proto::make_tensor_value_info(
        &output_name,
        TensorProto::FLOAT,
        &[1, c_out, h_out, w_out],
    );
    let c_start_uz = i64_to_usize(c_start, "create_channel_group_slice", "c_start")?;
    let c_end_uz = i64_to_usize(c_end, "create_channel_group_slice", "c_end")?;
    let sliced_weights = slice_weights(weights, c_start_uz, c_end_uz)?;
    let w_tensor = onnx_proto::make_tensor(
        "W",
        TensorProto::FLOAT,
        &sliced_weights.dims,
        sliced_weights.data,
    );
    // Unlike the spatial tiles, the original pads are kept: every group sees
    // the full h_in x w_in input.
    let mut conv_attrs = vec![
        onnx_proto::make_attribute_ints("kernel_shape", &cp.kernel),
        onnx_proto::make_attribute_ints("strides", &cp.stride),
        onnx_proto::make_attribute_ints("pads", &cp.pads),
        onnx_proto::make_attribute_ints("dilations", &cp.dilation),
    ];
    if cp.group != 1 {
        conv_attrs.push(onnx_proto::make_attribute_int("group", cp.group));
    }
    let node = onnx_proto::make_node(
        "Conv",
        vec![input_name, "W".to_string()],
        vec![output_name],
        conv_attrs,
    );
    let graph_proto = onnx_proto::make_graph(
        &format!("channel_group_{slice_idx}_{group_idx}"),
        vec![node],
        vec![x],
        vec![y],
        vec![w_tensor],
    );
    let group_model = onnx_proto::make_model(graph_proto, model_opset(model));
    let groups_dir = output_dir.join("channel_groups");
    std::fs::create_dir_all(&groups_dir)
        .map_err(|e| crate::error::DsperseError::io(e, &groups_dir))?;
    let onnx_path = groups_dir.join(format!("group_{group_idx}.onnx"));
    onnx_proto::save_model(&group_model, &onnx_path)?;
    Ok(ChannelGroupInfo {
        group_idx,
        c_start: c_start_uz,
        c_end: c_end_uz,
        path: format!("slice_{slice_idx}/payload/channel_groups/group_{group_idx}.onnx"),
        // Circuit artifacts are not produced here; presumably filled in by a
        // later compile stage.
        jstprove_circuit_path: None,
        jstprove_settings_path: None,
    })
}

/// Convert `val` to `usize`, naming `ctx` and `name` in the error when the
/// value is negative or does not fit.
fn i64_to_usize(val: i64, ctx: &str, name: &str) -> Result {
    usize::try_from(val).map_err(|_| {
        crate::error::DsperseError::Slicer(format!("{ctx}: {name} ({val}) out of range for usize"))
    })
}

/// Product of `factors` with overflow detection (1 for an empty slice).
///
/// The error text is prefixed with `slice_weights`, its caller in this file.
fn checked_dim_product(factors: &[usize]) -> Result {
    factors.iter().try_fold(1usize, |acc, &f| {
        acc.checked_mul(f).ok_or_else(|| {
            crate::error::DsperseError::Slicer(format!(
                "slice_weights: dimension product overflow (factors={factors:?})"
            ))
        })
    })
}

/// Copy the input-channel range `[c_start, c_end)` out of Conv weights.
///
/// `weights.data` is indexed as row-major `[c_out, c_in, kh, kw]` using the
/// first four entries of `weights.dims`; extra trailing dims are not
/// rejected, but the data-length check uses only those four. The result has
/// dims `[c_out, c_end - c_start, kh, kw]`.
///
/// # Errors
/// Fewer than 4 dims, a negative dim, overflow in a dim product, a data
/// length that does not match the dims, an empty range, or `c_end > c_in`.
fn slice_weights(weights: &WeightInfo, c_start: usize, c_end: usize) -> Result {
    if weights.dims.len() < 4 {
        return Err(crate::error::DsperseError::Slicer(format!(
            "slice_weights: expected >= 4 dims, got {}",
            weights.dims.len()
        )));
    }
    let to_usize = |dim: i64, name: &str| -> Result {
        usize::try_from(dim).map_err(|_| {
            crate::error::DsperseError::Slicer(format!(
                "slice_weights: {name} dimension {dim} is negative or too large"
            ))
        })
    };
    let c_out = to_usize(weights.dims[0], "c_out")?;
    let c_in = to_usize(weights.dims[1], "c_in")?;
    let kh = to_usize(weights.dims[2], "kh")?;
    let kw = to_usize(weights.dims[3], "kw")?;
    let expected_len = checked_dim_product(&[c_out, c_in, kh, kw])?;
    if weights.data.len() != expected_len {
        return Err(crate::error::DsperseError::Slicer(format!(
            "slice_weights: data length {} != expected {} (dims={:?})",
            weights.data.len(),
            expected_len,
            weights.dims
        )));
    }
    if c_start >= c_end {
        return Err(crate::error::DsperseError::Slicer(format!(
            "slice_weights: c_start ({c_start}) >= c_end ({c_end})"
        )));
    }
    if c_end > c_in {
        return Err(crate::error::DsperseError::Slicer(format!(
            "slice_weights: c_end ({c_end}) exceeds c_in ({c_in})"
        )));
    }
    let c_group = c_end - c_start;
    let capacity = checked_dim_product(&[c_out, c_group, kh, kw])?;
    // Despite the names: `stride_cin` is the element stride of the output
    // channel index (c_in * kh * kw) and `stride_kh` is the stride of the
    // input channel index (kh * kw).
    let stride_cin = checked_dim_product(&[c_in, kh, kw])?;
    let stride_kh = checked_dim_product(&[kh, kw])?;
    let mut sliced = Vec::with_capacity(capacity);
    for o in 0..c_out {
        for c in c_start..c_end {
            for h in 0..kh {
                for w_idx in 0..kw {
                    let idx = o * stride_cin + c * stride_kh + h * kw + w_idx;
                    sliced.push(weights.data[idx]);
                }
            }
        }
    }
    Ok(WeightInfo {
        data: sliced,
        dims: vec![c_out as i64, c_group as i64, kh as i64, kw as i64],
    })
}

/// Serialize the conv bias, if present, with `rmp_serde::to_vec_named` to
/// `<output_dir>/channel_groups/bias.msgpack`.
///
/// Returns the payload-relative path, or `None` when the prologue has no
/// bias. Directory, serialization and write errors are propagated.
fn save_conv_bias(
    prologue: &SlicePrologue<'_>,
    slice_idx: usize,
    output_dir: &Path,
) -> Result> {
    let Some(bias_data) = &prologue.bias else {
        return Ok(None);
    };
    let groups_dir = output_dir.join("channel_groups");
    std::fs::create_dir_all(&groups_dir)
        .map_err(|e| crate::error::DsperseError::io(e, &groups_dir))?;
    let bias_bytes = rmp_serde::to_vec_named(&bias_data)?;
    let bias_path = groups_dir.join("bias.msgpack");
    std::fs::write(&bias_path, bias_bytes)
        .map_err(|e| crate::error::DsperseError::io(e, &bias_path))?;
    Ok(Some(format!(
        "slice_{slice_idx}/payload/channel_groups/bias.msgpack"
    )))
}

/// Split a dense Conv slice along its input channels into `num_groups`
/// group models plus a separately stored bias, and return the split
/// metadata.
///
/// `cfg` must be all-positive, match the model exactly (c_in, c_out, h, w),
/// and describe non-empty groups that together cover every input channel.
/// If a failure occurs once group files start being written,
/// `<output_dir>/channel_groups` is removed.
///
/// NOTE(review): this signature has five parameters, so the
/// `too_many_arguments` allow below looks unnecessary.
#[allow(clippy::too_many_arguments)]
pub fn apply_channel_splitting(
    model: &ModelProto,
    cfg: &ChannelSplitParams,
    input_name: &str,
    output_name: &str,
    output_dir: &Path,
) -> Result {
    let &ChannelSplitParams {
        c_in,
        c_out,
        num_groups,
        channels_per_group,
        h,
        w,
        slice_idx,
    } = cfg;
    if c_in <= 0 || c_out <= 0 || num_groups <= 0 || channels_per_group <= 0 || h <= 0 || w <= 0 {
        return Err(crate::error::DsperseError::Slicer(format!(
            "apply_channel_splitting: invalid ChannelSplitParams (c_in={c_in}, c_out={c_out}, num_groups={num_groups}, channels_per_group={channels_per_group}, h={h}, w={w})"
        )));
    }
    let covered = num_groups.checked_mul(channels_per_group).ok_or_else(|| {
        crate::error::DsperseError::Slicer(
            "apply_channel_splitting: num_groups * channels_per_group overflow".to_string(),
        )
    })?;
    // The groups must cover every input channel...
    if covered < c_in {
        return Err(crate::error::DsperseError::Slicer(format!(
            "apply_channel_splitting: cfg covers only {covered} input channels, expected at least {c_in}",
        )));
    }
    // ...and the last group must start inside c_in so no group is empty.
    let last_group_start = (num_groups - 1)
        .checked_mul(channels_per_group)
        .ok_or_else(|| {
            crate::error::DsperseError::Slicer(
                "apply_channel_splitting: group start computation overflow".to_string(),
            )
        })?;
    if last_group_start >= c_in {
// ---- continuation of `apply_channel_splitting` ----
        return Err(crate::error::DsperseError::Slicer(format!(
            "apply_channel_splitting: cfg creates empty trailing groups (last_start={last_group_start}, c_in={c_in})"
        )));
    }
    let prologue = extract_slice_prologue(model).ok_or_else(|| {
        crate::error::DsperseError::Slicer(
            "apply_channel_splitting: failed to extract slice prologue from model".to_string(),
        )
    })?;
    let (_, _, model_c_in, model_h, model_w) =
        get_model_dimensions(prologue.graph).ok_or_else(|| {
            crate::error::DsperseError::Slicer(
                "apply_channel_splitting: unable to determine model dimensions".to_string(),
            )
        })?;
    let model_c_out = prologue.cp.c_out;
    // Only dense convolutions are split (presumably because `slice_weights`
    // treats weight dim 1 as the full input-channel axis).
    if prologue.cp.group != 1 {
        return Err(crate::error::DsperseError::Slicer(format!(
            "apply_channel_splitting: unsupported Conv group={}, expected 1",
            prologue.cp.group
        )));
    }
    if prologue.cp.c_in != model_c_in {
        return Err(crate::error::DsperseError::Slicer(format!(
            "apply_channel_splitting: weight/model c_in mismatch (weights c_in={}, model c_in={})",
            prologue.cp.c_in, model_c_in
        )));
    }
    // The caller-supplied cfg must describe exactly this model.
    if model_c_in != c_in || model_c_out != c_out || model_h != h || model_w != w {
        return Err(crate::error::DsperseError::Slicer(format!(
            "apply_channel_splitting: cfg dims (c_in={c_in}, c_out={c_out}, h={h}, w={w}) mismatch model dims (c_in={model_c_in}, c_out={model_c_out}, h={model_h}, w={model_w})"
        )));
    }
    let (out_h, out_w) = conv_output_hw(
        h,
        w,
        prologue.cp.pads,
        prologue.cp.kernel,
        prologue.cp.dilation,
        prologue.cp.stride,
    )
    .ok_or_else(|| {
        crate::error::DsperseError::Slicer(format!(
            "apply_channel_splitting: invalid conv output dimensions for h={h}, w={w}, stride={:?}, kernel={:?}",
            prologue.cp.stride, prologue.cp.kernel
        ))
    })?;
    // On any failure below, remove the (possibly partially written)
    // channel_groups directory. NOTE(review): this also deletes a directory
    // that already existed before this call.
    let groups_dir = output_dir.join("channel_groups");
    let cleanup = || {
        if groups_dir.exists() {
            let _ = std::fs::remove_dir_all(&groups_dir);
        }
    };
    let mut groups = Vec::new();
    for g in 0..num_groups {
        // The checked `covered` product above bounds these multiplications,
        // and the last group is clamped to c_in (it is non-empty by the
        // `last_group_start` check).
        let c_start = g * channels_per_group;
        let c_end = ((g + 1) * channels_per_group).min(c_in);
        let g_uz = i64_to_usize(g, "apply_channel_splitting", "group_idx").inspect_err(|_| {
            cleanup();
        })?;
        let group_info = match create_channel_group_slice(
            model, &prologue, g_uz, c_start, c_end, h, w, slice_idx, output_dir,
        ) {
            Ok(info) => info,
            Err(e) => {
                cleanup();
                return Err(e);
            }
        };
        groups.push(group_info);
    }
    let bias_path = match save_conv_bias(&prologue, slice_idx, output_dir) {
        Ok(p) => p,
        Err(e) => {
            cleanup();
            return Err(e);
        }
    };
    // c_in..w were validated positive above; the conversions still clean up
    // on failure for uniformity.
    let ctx = "apply_channel_splitting";
    let c_in_uz = i64_to_usize(c_in, ctx, "c_in").inspect_err(|_| cleanup())?;
    let c_out_uz = i64_to_usize(c_out, ctx, "c_out").inspect_err(|_| cleanup())?;
    let num_groups_uz = i64_to_usize(num_groups, ctx, "num_groups").inspect_err(|_| cleanup())?;
    let cpg_uz = i64_to_usize(channels_per_group, ctx, "channels_per_group").inspect_err(|_| cleanup())?;
    let h_uz = i64_to_usize(h, ctx, "h").inspect_err(|_| cleanup())?;
    let w_uz = i64_to_usize(w, ctx, "w").inspect_err(|_| cleanup())?;
    let out_h_uz = i64_to_usize(out_h, ctx, "out_h").inspect_err(|_| cleanup())?;
    let out_w_uz = i64_to_usize(out_w, ctx, "out_w").inspect_err(|_| cleanup())?;
    Ok(ChannelSplitInfo {
        slice_idx,
        c_in: c_in_uz,
        c_out: c_out_uz,
        num_groups: num_groups_uz,
        channels_per_group: cpg_uz,
        input_name: input_name.to_string(),
        output_name: output_name.to_string(),
        h: h_uz,
        w: w_uz,
        out_h: out_h_uz,
        out_w: out_w_uz,
        groups,
        bias_path,
    })
}

/// Write the compile-time template model for a dim-split slice to
/// `<output_dir>/dim_template.onnx` and return its path.
///
/// `MatMulOutputDim` splits get a single MatMul/Gemm template over one K
/// chunk; `HeadDim` / `BatchDim` splits get a rewritten copy of the slice
/// model. `traced_shapes` is only forwarded to the generic path.
///
/// # Errors
/// A model without a graph, or any error from the selected template builder.
pub fn create_dim_split_template(
    model: &ModelProto,
    info: &crate::schema::tiling::DimSplitInfo,
    output_dir: &Path,
    traced_shapes: Option<&HashMap>>,
) -> Result {
    let graph = model.graph.as_ref().ok_or_else(|| {
        crate::error::DsperseError::Slicer("create_dim_split_template: model has no graph".into())
    })?;
    match info.split_kind {
        crate::schema::tiling::DimSplitKind::MatMulOutputDim => {
            create_matmul_dim_template(model, graph, info, output_dir)
        }
        crate::schema::tiling::DimSplitKind::HeadDim
        | crate::schema::tiling::DimSplitKind::BatchDim => {
            create_generic_dim_template(model, graph, info, output_dir, traced_shapes)
        }
    }
}

/// Build the MatMul/Gemm template for a `MatMulOutputDim` split.
///
/// The template is `dim_tmpl_in [1, k_chunk] x W -> dim_tmpl_out [1, n]`,
/// where `k_chunk = ceil(K / k_chunks)` and `W` is a zero-filled placeholder
/// of the chunked shape (stored transposed when the source Gemm has
/// `transB = 1`). Gemm `alpha` / `beta` are carried over.
///
/// # Errors
/// Missing `weight_name`; no MatMul/Gemm whose input 0 is the slice input,
/// input 1 is the weight and an output is the slice output; a biased Gemm
/// or a Gemm with `transA = 1`; a weight that is not a 2-D initializer with
/// positive dims; or a save failure.
fn create_matmul_dim_template(
    model: &ModelProto,
    graph: &GraphProto,
    info: &crate::schema::tiling::DimSplitInfo,
    output_dir: &Path,
) -> Result {
    let weight_name = info.weight_name.as_ref().ok_or_else(|| {
        crate::error::DsperseError::Slicer(format!(
            "create_matmul_dim_template: slice {} DimSplitInfo missing weight_name",
            info.slice_idx
        ))
    })?;
    // Match the exact split node by weight, activation input, and output
    // name. A graph may reuse the same weight initializer in multiple
    // MatMul/Gemm ops (tied weights, weight sharing across heads); without
    // checking IO we could bind the wrong op and emit a template that
    // doesn't match the slice the runner will execute. Operand *positions*
    // are checked too: the template is always `activation x W`, so a node
    // like `MatMul(W, x)` must not be bound with its operands swapped.
    let matmul_node = graph
        .node
        .iter()
        .find(|n| {
            matches!(n.op_type.as_str(), "MatMul" | "Gemm")
                && n.input.first() == Some(&info.input_name)
                && n.input.get(1) == Some(weight_name)
                && n.output.iter().any(|o| o == &info.output_name)
        })
        .ok_or_else(|| {
            crate::error::DsperseError::Slicer(format!(
                "create_matmul_dim_template: slice {} no MatMul/Gemm matches weight={weight_name:?} input={:?} output={:?} (activation must be input 0, weight input 1)",
                info.slice_idx, info.input_name, info.output_name
            ))
        })?;
    if matmul_node.op_type == "Gemm" && matmul_node.input.get(2).is_some_and(|s| !s.is_empty()) {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_matmul_dim_template: slice {} Gemm with bias not supported for dim-split",
            info.slice_idx
        )));
    }
    let weight_tensor = graph
        .initializer
        .iter()
        .find(|i| i.name == *weight_name)
        .ok_or_else(|| {
            crate::error::DsperseError::Slicer(format!(
                "create_matmul_dim_template: weight {weight_name:?} not in initializers"
            ))
        })?;
    if weight_tensor.dims.len() != 2 {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_matmul_dim_template: expected 2D weights, got {:?}",
            weight_tensor.dims
        )));
    }
    // Non-positive dims would wrap in the `as usize` casts below.
    if weight_tensor.dims.iter().any(|&d| d <= 0) {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_matmul_dim_template: slice {} weight {weight_name:?} has non-positive dims {:?}",
            info.slice_idx, weight_tensor.dims
        )));
    }
    if matmul_node.op_type == "Gemm"
        && onnx_proto::get_attribute_int(matmul_node, "transA").unwrap_or(0) == 1
    {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_matmul_dim_template: slice {} Gemm with transA=1 is not supported for dim-split",
            info.slice_idx
        )));
    }
    let trans_b = matmul_node.op_type == "Gemm"
        && onnx_proto::get_attribute_int(matmul_node, "transB").unwrap_or(0) == 1;
    let (rows, cols) = (
        weight_tensor.dims[0] as usize,
        weight_tensor.dims[1] as usize,
    );
    // With transB the stored weight is [N, K]; otherwise [K, N].
    let (k_dim, n_dim) = if trans_b { (cols, rows) } else { (rows, cols) };
    let k_chunk_size = k_dim.div_ceil(info.k_chunks.max(1));
    let tmpl_input_name = "dim_tmpl_in".to_string();
    let tmpl_output_name = "dim_tmpl_out".to_string();
    let tmpl_weight_name = "W".to_string();
    let tmpl_input_shape: Vec = vec![1, k_chunk_size as i64];
    let output_shape: Vec = vec![1, n_dim as i64];
    let x = onnx_proto::make_tensor_value_info(&tmpl_input_name, TensorProto::FLOAT, &tmpl_input_shape);
    let y = onnx_proto::make_tensor_value_info(&tmpl_output_name, TensorProto::FLOAT, &output_shape);
    let tmpl_weight_dims: Vec = if trans_b {
        vec![n_dim as i64, k_chunk_size as i64]
    } else {
        vec![k_chunk_size as i64, n_dim as i64]
    };
    // Placeholder weights: only the shape matters for the template.
    let w = onnx_proto::make_tensor(
        &tmpl_weight_name,
        TensorProto::FLOAT,
        &tmpl_weight_dims,
        vec![0.0f32; k_chunk_size * n_dim],
    );
    let mut attrs = Vec::new();
    let node_inputs = vec![tmpl_input_name, tmpl_weight_name];
    let initializers = vec![w];
    if matmul_node.op_type == "Gemm" {
        if let Some(alpha) = onnx_proto::get_attribute_float(matmul_node, "alpha") {
            attrs.push(onnx_proto::make_attribute_float("alpha", alpha));
        }
        if let Some(beta) = onnx_proto::get_attribute_float(matmul_node, "beta") {
            attrs.push(onnx_proto::make_attribute_float("beta", beta));
        }
        // transA is rejected above; the template always uses A non-transposed.
        if trans_b {
            attrs.push(onnx_proto::make_attribute_int("transB", 1));
        }
        // Biased Gemm is rejected above, so no C initializer is ever folded
        // into the template.
    }
    let node = onnx_proto::make_node(
        &matmul_node.op_type,
        node_inputs,
        vec![tmpl_output_name],
        attrs,
    );
    let graph_proto = onnx_proto::make_graph(
        &format!("dim_template_{}", info.slice_idx),
        vec![node],
        vec![x],
        vec![y],
        initializers,
    );
    let tmpl_model = onnx_proto::make_model(graph_proto, model_opset(model));
    let tmpl_path = output_dir.join("dim_template.onnx");
    onnx_proto::save_model(&tmpl_model, &tmpl_path)?;
    Ok(tmpl_path)
}

/// Reject a generic dim-split whose `split_dim` is coupled by an op in the
/// slice graph.
///
/// Negative axes resolve against the rank of the graph's first input
/// (assumed 4 when that shape is unknown); an axis that is still negative
/// after resolution is rejected. Checks per op:
/// * `Flatten`: `split_dim` must not fall before `axis` (the merged leading
///   group).
/// * `Softmax` / `LogSoftmax`: opset >= 13 normalizes along `axis` alone, so
///   `split_dim` must differ from it; opset < 13 coerces the input to 2-D at
///   `axis`, so `split_dim` must lie before it.
/// * `LayerNormalization`: normalizes over `axis..rank`, so `split_dim`
///   must lie before `axis`.
/// * `BatchNormalization`: `split_dim` must not be 0.
fn check_axis_separable(
    graph: &GraphProto,
    split_dim: usize,
    slice_idx: usize,
    model_opset: i64,
) -> Result<()> {
    // The graph is read-only here, so the input rank is computed once.
    let ndim = graph
        .input
        .first()
        .and_then(onnx_proto::shape_from_value_info)
        .map(|s| s.len() as i64)
        .unwrap_or(4);
    // Out-of-range negative axes are rejected rather than cast to usize,
    // where they would wrap to a huge index and slip past the `==` / `<=`
    // checks below.
    let resolve_axis = |axis: i64, op: &str| -> Result<usize> {
        let resolved = if axis < 0 { ndim + axis } else { axis };
        usize::try_from(resolved).map_err(|_| {
            crate::error::DsperseError::Slicer(format!(
                "create_generic_dim_template: slice {slice_idx} {op} axis {axis} \
                 is out of range for rank {ndim}"
            ))
        })
    };
    for node in &graph.node {
        match node.op_type.as_str() {
            "Flatten" => {
                let axis = resolve_axis(
                    onnx_proto::get_attribute_int(node, "axis").unwrap_or(1),
                    "Flatten",
                )?;
                if split_dim < axis {
                    return Err(crate::error::DsperseError::Slicer(format!(
                        "create_generic_dim_template: slice {slice_idx} Flatten axis \
                         {axis} > split_dim {split_dim}; split dimension falls in the merged leading group"
                    )));
                }
            }
            "Softmax" | "LogSoftmax" => {
                // ONNX Softmax / LogSoftmax default axis: opset >= 13 -> -1
                // (last axis), opset < 13 -> 1 (channel axis).
                let default_axis: i64 = if model_opset >= 13 { -1 } else { 1 };
                let resolved = resolve_axis(
                    onnx_proto::get_attribute_int(node, "axis").unwrap_or(default_axis),
                    node.op_type.as_str(),
                )?;
                if model_opset >= 13 {
                    // Opset >= 13 normalizes along `axis` alone.
                    if resolved == split_dim {
                        return Err(crate::error::DsperseError::Slicer(format!(
                            "create_generic_dim_template: slice {slice_idx} {} axis {resolved} \
                             equals split_dim {split_dim}; normalization spans the split dimension",
                            node.op_type
                        )));
                    }
                } else if resolved <= split_dim {
                    // Opset < 13 coerces the input to 2-D as
                    // [d_0 * .. * d_{axis-1}, d_axis * .. * d_{n-1}] and
                    // normalizes each row, so every dim from `axis` onward
                    // is coupled (matching the detector's `max_allowed`
                    // cap at the softmax axis).
                    return Err(crate::error::DsperseError::Slicer(format!(
                        "create_generic_dim_template: slice {slice_idx} {} axis {resolved} \
                         <= split_dim {split_dim} (opset {model_opset} < 13 flattens from axis); \
                         normalization spans the split dimension",
                        node.op_type
                    )));
                }
            }
            "LayerNormalization" => {
                // LayerNormalization normalizes over dims `axis..rank`.
                let resolved = resolve_axis(
                    onnx_proto::get_attribute_int(node, "axis").unwrap_or(-1),
                    "LayerNormalization",
                )?;
                if resolved <= split_dim {
                    return Err(crate::error::DsperseError::Slicer(format!(
                        "create_generic_dim_template: slice {slice_idx} LayerNormalization axis \
                         {resolved} <= split_dim {split_dim}; normalization spans the split dimension",
                    )));
                }
            }
            "BatchNormalization" if split_dim == 0 => {
                return Err(crate::error::DsperseError::Slicer(format!(
                    "create_generic_dim_template: slice {slice_idx} BatchNormalization requires \
                     full batch statistics; cannot split at dim 0"
                )));
            }
            _ => {}
        }
    }
    Ok(())
}

fn create_generic_dim_template(
    model: &ModelProto,
    graph: &GraphProto,
    info: &crate::schema::tiling::DimSplitInfo,
    output_dir: &Path,
    traced_shapes: Option<&HashMap>>,
) -> Result {
    if info.elements_per_group == 0 {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_generic_dim_template: slice {} elements_per_group is 0",
            info.slice_idx
        )));
    }
    check_axis_separable(graph, info.split_dim, info.slice_idx, model_opset(model))?;
    // Rewrite the template so the split axis carries elements_per_group
    // instead of the full dim_size. The runner only ever feeds a single
    // group's worth of activations to the compiled circuit, so the
    // *compile* cost should match the per-group cost rather than the
    // whole-slice cost. Catalog reuse is preserved at per-group
    // granularity: any two slices that share (split_dim, epg, surrounding
    // op shapes) hash identically.
// // The strategy is: rewrite only the boundary shapes (graph inputs + // shape-input initializers consumed by Reshape / Expand / Tile / // ConstantOfShape) and a fresh shape inference pass derives every // intermediate value_info from those. Per-feature initializers // (gamma, beta, weights) are never touched, and there are no ad-hoc // cases for individual op patterns -- the rule is "rewrite the // boundary, let inference do the rest". let mut tmpl_model = model.clone(); let tmpl_graph = tmpl_model.graph.as_mut().ok_or_else(|| { crate::error::DsperseError::Slicer( "create_generic_dim_template: cloned model has no graph".into(), ) })?; let dim_size = info.dim_size as i64; let epg = info.elements_per_group as i64; let split_dim = info.split_dim; // 1. Decide which graph inputs must be rewritten at split_dim. // // The runner always slices every cached tensor whose shape has // dim_size at split_dim, so the *compile-time* template needs // every such input declared with epg, otherwise jstprove's // type checker rejects the op (e.g. Mul broadcast 150 vs 300, // or MatMul A.K vs B.K mismatch). But for ops where two // inputs reference dim_size at the *same* split_dim with // different semantic meanings (the canonical case is the // second attention MatMul: attn[B,H,M,N] @ V[B,H,N,D] with // M == N at split_dim=2) blanket rewriting both inputs // produces a real mismatch. // // Heuristic: // * Elementwise / broadcast ops (Add, Sub, Mul, Div, Pow, Min, // Max, Where, Equal, Greater, Less): rewrite every input // whose shape has dim_size at split_dim. All inputs share // a logical broadcast axis, so all must shrink together. // * MatMul / Gemm: rewrite only `info.input_name`. The other // operand's split_dim is a contraction axis; touching it // produces an inner-dim mismatch. 
// * Everything else (the single-op slices we get after // isolate_expensive_ops): rewrite only `info.input_name`, // which is the safe default for ops with one primary // activation and a handful of scalar / per-feature // initializer inputs. let elementwise_ops: HashSet<&str> = [ "Add", "Sub", "Mul", "Div", "Pow", "Min", "Max", "Where", "Equal", "Greater", "Less", ] .into_iter() .collect(); let rewrite_all_matching = tmpl_graph .node .iter() .all(|n| elementwise_ops.contains(n.op_type.as_str())); let rewrite_input_at_split_dim = |vi: &mut super::onnx_proto::ValueInfoProto| { if let Some(t) = vi.r#type.as_mut() && let Some(super::onnx_proto::onnx::type_proto::Value::TensorType(tt)) = t.value.as_mut() && let Some(shape) = tt.shape.as_mut() && let Some(d) = shape.dim.get_mut(split_dim) && let Some(super::onnx_proto::onnx::tensor_shape_proto::dimension::Value::DimValue(v)) = d.value.as_mut() && *v == dim_size { *v = epg; } }; if rewrite_all_matching { for vi in tmpl_graph .input .iter_mut() .chain(tmpl_graph.output.iter_mut()) { rewrite_input_at_split_dim(vi); } } else { for vi in tmpl_graph .input .iter_mut() .filter(|vi| vi.name == info.input_name) .chain( tmpl_graph .output .iter_mut() .filter(|vi| vi.name == info.output_name), ) { rewrite_input_at_split_dim(vi); } } // 2. Rewrite shape-input initializers (Reshape / Expand / Tile / // ConstantOfShape). These are explicit shape descriptors; if // the input shape changes their dim_size entry must change too. 
let shape_input_initializers: HashSet = tmpl_graph .node .iter() .filter_map(|n| match n.op_type.as_str() { "Reshape" | "Expand" | "Tile" => n.input.get(1).cloned(), "ConstantOfShape" => n.input.first().cloned(), _ => None, }) .filter(|name| !name.is_empty()) .collect(); for init in &mut tmpl_graph.initializer { if init.data_type == TensorProto::INT64 && shape_input_initializers.contains(&init.name) { // ONNX TensorProto INT64 payloads can live in either // int64_data (typed field) or raw_data (little-endian // i64 byte stream); larger constants tend to use // raw_data. Patch both representations. for v in &mut init.int64_data { if *v == dim_size { *v = epg; } } if !init.raw_data.is_empty() && init.raw_data.len() % 8 == 0 { let mut buf: Vec = init .raw_data .chunks_exact(8) .map(|c| i64::from_le_bytes(c.try_into().unwrap())) .collect(); let mut changed = false; for v in &mut buf { if *v == dim_size { *v = epg; changed = true; } } if changed { let mut new_raw = Vec::with_capacity(buf.len() * 8); for v in &buf { new_raw.extend_from_slice(&v.to_le_bytes()); } init.raw_data = new_raw; } } } } // 3. Drop every intermediate value_info; it will be re-derived. tmpl_graph.value_info.clear(); let _ = traced_shapes; // intentionally unused: we re-trace after rewriting. let tmpl_path = output_dir.join("dim_template.onnx"); onnx_proto::save_model(&tmpl_model, &tmpl_path)?; // 4. Re-run shape inference on the rewritten template and inject // the derived shapes back as value_info. This replaces the old // ad-hoc per-op rewrites (which had to special-case every shape // op). If re-trace fails the template is uncompilable -- the // circuit compiler downstream will see no value_info for the // intermediate tensors and produce hard-to-diagnose shape // errors at compile time. Refuse to emit the template instead. 
let trace = super::trace::fold_and_trace_via_tract(&tmpl_path, &tmpl_model).map_err( |e| { crate::error::DsperseError::Slicer(format!( "create_generic_dim_template: slice {} re-trace failed (template input shape {:?}, split_dim {}): {e}", info.slice_idx, info.input_name, split_dim )) }, )?; { let mut model_after = onnx_proto::load_model(&tmpl_path)?; if let Some(graph_after) = model_after.graph.as_mut() { let existing: HashSet = graph_after .input .iter() .chain(graph_after.output.iter()) .chain(graph_after.value_info.iter()) .map(|vi| vi.name.clone()) .collect(); let init_names: HashSet<&str> = graph_after .initializer .iter() .map(|i| i.name.as_str()) .collect(); for node in &graph_after.node { for out_name in &node.output { if out_name.is_empty() || existing.contains(out_name) || init_names.contains(out_name.as_str()) { continue; } if let Some(shape) = trace.shapes.get(out_name) { let elem_type = trace .types .get(out_name) .copied() .unwrap_or(TensorProto::FLOAT); graph_after .value_info .push(onnx_proto::make_tensor_value_info( out_name, elem_type, shape, )); } } } // Promote output_name to graph output if it now exists in // value_info but not in graph.output. 
if !graph_after
                .output
                .iter()
                .any(|o| o.name == info.output_name)
                && let Some(vi) = graph_after
                    .value_info
                    .iter()
                    .find(|v| v.name == info.output_name)
                    .cloned()
            {
                graph_after.output.push(vi);
            }
        }
        onnx_proto::save_model(&model_after, &tmpl_path)?;
    }
    Ok(tmpl_path)
}

/// Build a fixed-size, 1-D tile model for an element-wise slice and write it
/// to `<output_dir>/tiles/tile.onnx`.
///
/// * Every graph input that is an initializer, or whose declared element
///   count is below `segment_size`, is kept unchanged as a broadcast operand.
/// * Every other input is replaced by a `[segment_size]` input named
///   `tile_in` (when it is the only such input) or `tile_in_0`,
///   `tile_in_1`, ... in declaration order, keeping its element type.
/// * The first graph output is renamed `tile_out` with shape
///   `[segment_size]` and the original output's element type. Other graph
///   outputs are not exposed by the tile model.
/// * Nodes are copied with inputs/outputs remapped accordingly; node names,
///   domains and other metadata are cleared. Initializers are copied as-is.
///
/// Returns the tile path relative to the model root (hard-coded as
/// `slice_{slice_idx}/payload/tiles/tile.onnx`; presumably `output_dir` is that
/// slice's `payload` directory — confirm with callers) and
/// `conv_out = [segment_size, 1]`.
///
/// # Errors
/// `DsperseError::Slicer` when `segment_size <= 0` or the model has no graph,
/// no inputs or no outputs; I/O / serialization errors from writing the tile.
pub fn create_elementwise_tile_slice(
    model: &ModelProto,
    segment_size: i64,
    slice_idx: usize,
    output_dir: &Path,
) -> Result<TileSliceResult> {
    if segment_size <= 0 {
        return Err(crate::error::DsperseError::Slicer(format!(
            "create_elementwise_tile_slice: segment_size must be > 0, got {segment_size}"
        )));
    }
    let graph = model.graph.as_ref().ok_or_else(|| {
        crate::error::DsperseError::Slicer(
            "create_elementwise_tile_slice: model.graph is None".to_string(),
        )
    })?;
    if graph.input.is_empty() {
        return Err(crate::error::DsperseError::Slicer(
            "create_elementwise_tile_slice: no graph inputs".to_string(),
        ));
    }
    let out = graph.output.first().ok_or_else(|| {
        crate::error::DsperseError::Slicer(
            "create_elementwise_tile_slice: no graph outputs".to_string(),
        )
    })?;
    let orig_output_name = &out.name;
    let tile_shape: Vec<i64> = vec![segment_size];

    let init_names: std::collections::HashSet<&str> =
        graph.initializer.iter().map(|i| i.name.as_str()).collect();

    // An input is a broadcast operand (kept untouched) when it is an
    // initializer or its declared element count is below one segment.
    // Non-positive dims (dynamic `-1` / placeholder `0`) are clamped to 1
    // before multiplying; otherwise a dynamic-batch activation would produce
    // a negative/zero product and be mistaken for a tiny broadcast operand.
    // The product is overflow-checked; an overflowing shape is "large".
    let is_broadcast = |inp: &onnx_proto::ValueInfoProto| -> bool {
        init_names.contains(inp.name.as_str())
            || onnx_proto::shape_from_value_info(inp)
                .as_ref()
                .is_some_and(|s| {
                    s.iter()
                        .map(|&d| d.max(1))
                        .try_fold(1i64, |acc, d| acc.checked_mul(d))
                        .is_some_and(|elems| elems < segment_size)
                })
    };

    // Decide naming up front: a lone tiled input is always `tile_in`.
    let single_tiled_input = graph
        .input
        .iter()
        .filter(|&inp| !is_broadcast(inp))
        .count()
        == 1;

    let mut orig_to_tile: Vec<(String, String)> = Vec::with_capacity(graph.input.len());
    let mut tile_inputs = Vec::with_capacity(graph.input.len());
    for inp in &graph.input {
        if is_broadcast(inp) {
            tile_inputs.push(inp.clone());
            continue;
        }
        let tile_name = if single_tiled_input {
            "tile_in".to_string()
        } else {
            format!("tile_in_{}", orig_to_tile.len())
        };
        tile_inputs.push(onnx_proto::make_tensor_value_info(
            &tile_name,
            onnx_proto::elem_type_from_value_info(inp).unwrap_or(TensorProto::FLOAT),
            &tile_shape,
        ));
        orig_to_tile.push((inp.name.clone(), tile_name));
    }

    // Preserve the original output's element type (it need not be FLOAT);
    // fall back to FLOAT only when the output declares no type, mirroring
    // how tiled inputs are typed above.
    let out_elem_type =
        onnx_proto::elem_type_from_value_info(out).unwrap_or(TensorProto::FLOAT);
    let y = onnx_proto::make_tensor_value_info("tile_out", out_elem_type, &tile_shape);

    let initializers: Vec<_> = graph.initializer.to_vec();
    let input_remap: std::collections::HashMap<&str, &str> = orig_to_tile
        .iter()
        .map(|(k, v)| (k.as_str(), v.as_str()))
        .collect();

    // NOTE(review): node names and domains are cleared, which assumes every
    // node lives in the default ONNX domain; confirm for custom-domain ops.
    let nodes: Vec<NodeProto> = graph
        .node
        .iter()
        .map(|orig_node| NodeProto {
            op_type: orig_node.op_type.clone(),
            input: orig_node
                .input
                .iter()
                .map(|name| {
                    input_remap
                        .get(name.as_str())
                        .map_or_else(|| name.clone(), |s| (*s).to_string())
                })
                .collect(),
            output: orig_node
                .output
                .iter()
                .map(|o| {
                    if o == orig_output_name {
                        "tile_out".to_string()
                    } else {
                        o.clone()
                    }
                })
                .collect(),
            attribute: orig_node.attribute.clone(),
            ..Default::default()
        })
        .collect();

    let tile_graph = onnx_proto::make_graph(
        &format!("tile_{slice_idx}"),
        nodes,
        tile_inputs,
        vec![y],
        initializers,
    );
    let tile_model = onnx_proto::make_model(tile_graph, model_opset(model));

    let tiles_dir = output_dir.join("tiles");
    std::fs::create_dir_all(&tiles_dir)
        .map_err(|e| crate::error::DsperseError::io(e, &tiles_dir))?;
    let onnx_path = tiles_dir.join("tile.onnx");
    onnx_proto::save_model(&tile_model, &onnx_path)?;

    Ok(TileSliceResult {
        path: format!("slice_{slice_idx}/payload/tiles/tile.onnx"),
        conv_out: [segment_size, 1],
    })
}

/// Description of a tile model written to disk by a tile-slice builder.
#[derive(Debug)]
pub struct TileSliceResult {
    /// Path of the emitted `tile.onnx`, relative to the model root
    /// (e.g. `slice_{idx}/payload/tiles/tile.onnx`).
    pub path: String,
    /// Output extent of one tile; the element-wise builder reports
    /// `[segment_size, 1]`. Presumably (height, width) for spatial tiles —
    /// confirm against the other builders.
    pub conv_out: [i64; 2],
}

// Unit tests for the tiling / dim-split helpers defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn halo_symmetric_pads() {
        assert_eq!(compute_halo_size([1, 1, 1, 1]), Some([1, 1, 1, 1]));
    }

    #[test]
    fn
halo_asymmetric_pads() {
        assert_eq!(compute_halo_size([6, 6, 7, 7]), Some([6, 6, 7, 7]));
    }

    #[test]
    fn halo_zero_pads() {
        assert_eq!(compute_halo_size([0, 0, 0, 0]), Some([0, 0, 0, 0]));
    }

    #[test]
    fn halo_negative_pads_rejected() {
        assert_eq!(compute_halo_size([-1, 0, 0, 0]), None);
    }

    #[test]
    fn halo_mixed_pads() {
        assert_eq!(compute_halo_size([1, 2, 1, 2]), Some([1, 2, 1, 2]));
    }

    #[test]
    fn min_tile_3x3_no_dilation() {
        assert_eq!(compute_min_spatial_tile([3, 3], [1, 1]), Some(4));
    }

    #[test]
    fn min_tile_5x5_no_dilation() {
        assert_eq!(compute_min_spatial_tile([5, 5], [1, 1]), Some(6));
    }

    #[test]
    fn min_tile_3x3_dilation_2() {
        // Dilated kernel extent: (k - 1) * d + 1 = 5 for k = 3, d = 2.
        let eff = (3 - 1) * 2 + 1;
        assert_eq!(compute_min_spatial_tile([3, 3], [2, 2]), Some(eff + 1));
    }

    #[test]
    fn min_tile_1x1() {
        assert_eq!(compute_min_spatial_tile([1, 1], [1, 1]), Some(2));
    }

    #[test]
    fn optimal_tile_exact_divisor() {
        assert_eq!(find_optimal_tile_size(64, 32, 4, 1), Some(32));
    }

    #[test]
    fn optimal_tile_no_exact_divisor_falls_back() {
        assert_eq!(find_optimal_tile_size(64, 30, 4, 1), Some(16));
    }

    #[test]
    fn optimal_tile_target_equals_spatial() {
        assert_eq!(find_optimal_tile_size(32, 32, 4, 1), None);
    }

    #[test]
    fn optimal_tile_min_exceeds_target() {
        assert_eq!(find_optimal_tile_size(64, 3, 4, 1), None);
    }

    #[test]
    fn optimal_tile_stride_constraint() {
        assert_eq!(find_optimal_tile_size(64, 32, 4, 2), Some(32));
        assert_eq!(find_optimal_tile_size(12, 8, 2, 4), Some(4));
    }

    #[test]
    fn optimal_tile_no_valid_stride_divisor() {
        assert_eq!(find_optimal_tile_size(15, 10, 2, 4), None);
    }

    #[test]
    fn checked_dim_product_normal() {
        assert_eq!(checked_dim_product(&[2, 3, 4]).unwrap(), 24);
    }

    #[test]
    fn checked_dim_product_empty() {
        assert_eq!(checked_dim_product(&[]).unwrap(), 1);
    }

    #[test]
    fn checked_dim_product_overflow() {
        assert!(checked_dim_product(&[usize::MAX, 2]).is_err());
    }

    #[test]
    fn checked_dim_product_single() {
        assert_eq!(checked_dim_product(&[42]).unwrap(), 42);
    }

    #[test]
    fn slice_weights_basic() {
        let weights = WeightInfo {
            data: (0..24).map(|i| i as f32).collect(),
            dims: vec![2, 3, 2, 2],
        };
        // Keeping channels [0, 2) of dim 1 leaves 2 * 2 * 2 * 2 = 16 values.
        let sliced = slice_weights(&weights, 0, 2).unwrap();
        assert_eq!(sliced.dims, vec![2, 2, 2, 2]);
        assert_eq!(sliced.data.len(), 16);
        assert_eq!(sliced.data[0], 0.0);
        assert_eq!(sliced.data[1], 1.0);
        assert_eq!(sliced.data[2], 2.0);
        assert_eq!(sliced.data[3], 3.0);
    }

    #[test]
    fn slice_weights_single_channel() {
        let weights = WeightInfo {
            data: (0..24).map(|i| i as f32).collect(),
            dims: vec![2, 3, 2, 2],
        };
        let sliced = slice_weights(&weights, 1, 2).unwrap();
        assert_eq!(sliced.dims, vec![2, 1, 2, 2]);
        assert_eq!(sliced.data.len(), 8);
    }

    #[test]
    fn slice_weights_start_ge_end() {
        let weights = WeightInfo {
            data: vec![1.0; 16],
            dims: vec![1, 4, 2, 2],
        };
        assert!(slice_weights(&weights, 3, 2).is_err());
    }

    #[test]
    fn slice_weights_end_exceeds_c_in() {
        let weights = WeightInfo {
            data: vec![1.0; 16],
            dims: vec![1, 4, 2, 2],
        };
        assert!(slice_weights(&weights, 0, 5).is_err());
    }

    #[test]
    fn slice_weights_insufficient_dims() {
        let weights = WeightInfo {
            data: vec![1.0; 6],
            dims: vec![2, 3],
        };
        assert!(slice_weights(&weights, 0, 1).is_err());
    }

    #[test]
    fn slice_weights_data_length_mismatch() {
        // 10 values cannot fill dims [2, 3, 2, 2] (24 elements).
        let weights = WeightInfo {
            data: vec![1.0; 10],
            dims: vec![2, 3, 2, 2],
        };
        assert!(slice_weights(&weights, 0, 2).is_err());
    }

    #[test]
    fn elementwise_ops_recognized() {
        assert!(is_elementwise("Relu"));
        assert!(is_elementwise("Sigmoid"));
        assert!(is_elementwise("Add"));
        assert!(is_elementwise("Mul"));
    }

    #[test]
    fn non_elementwise_ops_rejected() {
        assert!(!is_elementwise("Conv"));
        assert!(!is_elementwise("MaxPool"));
        assert!(!is_elementwise("Gemm"));
        assert!(!is_elementwise("BatchNormalization"));
    }

    #[test]
    fn spatial_tile_config_already_fits() {
        let (tile, reason) = calculate_spatial_tile_config(3, 4, 4, 64, 4, 1);
        assert!(tile.is_none());
        assert_eq!(reason, Some("already_fits"));
    }

    #[test]
    fn spatial_tile_config_min_tile_too_large() {
        let (tile, reason) = calculate_spatial_tile_config(64, 8, 8, 100, 8, 1);
        assert!(tile.is_none());
        assert_eq!(reason, Some("min_tile_too_large"));
    }

    #[test]
    fn spatial_tile_config_finds_tile() {
        let (tile, reason) = calculate_spatial_tile_config(3, 64, 64, 3 * 32 * 32, 4, 1);
        assert!(tile.is_some());
        assert!(reason.is_none());
        let t = tile.unwrap();
        // The chosen tile must evenly divide the 64-wide input and be at
        // least the minimum tile of 4 passed in.
        assert!(64 % t == 0);
        assert!(t >= 4);
    }

    #[test]
    fn channel_split_config_basic() {
        let result = calculate_channel_split_config(64, 32, 4, 4, 32);
        assert!(result.is_some());
        let (num_groups, cpg) = result.unwrap();
        assert!(num_groups > 1);
        assert!(cpg > 0);
        // The last group must be non-empty: the first num_groups - 1 groups
        // cover fewer than 64 channels (presumably the channel count passed
        // as the first argument).
        assert!(cpg * (num_groups - 1) < 64);
    }

    #[test]
    fn channel_split_config_zero_dims() {
        assert!(calculate_channel_split_config(64, 32, 0, 4, 32).is_none());
        assert!(calculate_channel_split_config(64, 32, 4, 0, 32).is_none());
    }

    #[test]
    fn channel_split_config_fits_without_splitting() {
        assert!(calculate_channel_split_config(4, 32, 2, 2, 100).is_none());
    }

    #[test]
    fn detect_tiling_none_without_tile_size() {
        let model = onnx_proto::make_model(
            onnx_proto::make_graph("test", vec![], vec![], vec![], vec![]),
            13,
        );
        assert!(detect_tiling_needs(&model, None).is_none());
    }

    #[test]
    fn detect_tiling_none_empty_graph() {
        let model = onnx_proto::make_model(
            onnx_proto::make_graph("test", vec![], vec![], vec![], vec![]),
            13,
        );
        assert!(detect_tiling_needs(&model, Some(1024)).is_none());
    }

    #[test]
    fn effective_kernel_overflow() {
        assert_eq!(effective_kernel([i64::MAX, 1], [2, 1]), None);
        assert_eq!(effective_kernel([1, i64::MAX], [1, 2]), None);
    }

    #[test]
    fn effective_kernel_sub_underflow() {
        assert_eq!(effective_kernel([i64::MIN, 3], [1, 1]), None);
    }

    #[test]
    fn effective_kernel_valid() {
        assert_eq!(effective_kernel([3, 3], [1, 1]), Some([3, 3]));
        assert_eq!(effective_kernel([3, 3], [2, 2]), Some([5, 5]));
        assert_eq!(effective_kernel([1, 1], [1, 1]), Some([1, 1]));
    }

    #[test]
    fn conv_output_hw_zero_stride() {
        assert_eq!(
            conv_output_hw(8, 8, [0, 0, 0, 0], [3, 3], [1, 1], [0, 1]),
            None
        );
        assert_eq!(
            conv_output_hw(8, 8, [0, 0, 0, 0], [3, 3], [1, 1], [1, 0]),
            None
        );
    }

    #[test]
    fn conv_output_hw_kernel_exceeds_input() {
        assert_eq!(
            conv_output_hw(2, 2, [0, 0, 0, 0], [5, 5], [1, 1], [1, 1]),
            None
        );
    }

    #[test]
    fn conv_output_hw_overflow_pads() {
        assert_eq!(
            conv_output_hw(i64::MAX, 8, [1, 0, 0, 0], [3, 3], [1, 1], [1, 1]),
            None
        );
    }

    #[test]
    fn conv_output_hw_valid() {
        assert_eq!(
            conv_output_hw(8, 8, [1, 1, 1, 1], [3, 3], [1, 1], [1, 1]),
            Some((8, 8))
        );
        assert_eq!(
            conv_output_hw(8, 8, [0, 0, 0, 0], [3, 3], [1, 1], [2, 2]),
            Some((3, 3))
        );
    }

    #[test]
    fn compute_halo_size_negative_rejected() {
        assert_eq!(compute_halo_size([0, 0, -1, 0]), None);
    }

    #[test]
    fn compute_min_spatial_tile_overflow() {
        assert_eq!(compute_min_spatial_tile([i64::MAX, 1], [2, 1]), None);
    }

    #[test]
    fn slice_weights_full_range_is_identity() {
        let data: Vec = (0..48).map(|i| i as f32).collect();
        let weights = WeightInfo {
            data: data.clone(),
            dims: vec![2, 3, 2, 4],
        };
        let sliced = slice_weights(&weights, 0, 3).unwrap();
        assert_eq!(sliced.dims, vec![2, 3, 2, 4]);
        assert_eq!(sliced.data, data);
    }

    #[test]
    fn detect_dim_split_gemm_trans_b() {
        use super::onnx_proto::{NodeProto, make_attribute_int};
        // Unbiased Gemm with transB=1. Biased Gemm is rejected upstream by
        // create_matmul_dim_template, so the detector now skips it as well.
        let node = NodeProto {
            op_type: "Gemm".to_string(),
            input: vec!["input".to_string(), "weight".to_string()],
            output: vec!["output".to_string()],
            attribute: vec![make_attribute_int("transB", 1)],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        shapes.insert("input".to_string(), vec![4, 145, 384]);
        shapes.insert("weight".to_string(), vec![1536, 384]);
        shapes.insert("output".to_string(), vec![4, 145, 1536]);
        let mut init_names = HashSet::new();
        init_names.insert("weight".to_string());
        let detection = detect_dim_split(&[node], &shapes, &init_names, 17);
        assert!(detection.is_some());
        let d = detection.unwrap();
        // Expected dim_size 580 == 4 * 145: the input's leading dims are
        // counted as rows along split_dim 0.
        assert_eq!(d.split_dim, 0);
        assert_eq!(d.dim_size, 580);
        assert_eq!(d.num_groups, 580);
        assert_eq!(d.elements_per_group, 1);
        assert_eq!(d.k_dim, 384);
        assert_eq!(d.n_dim, 1536);
        assert!(matches!(d.split_kind, DimSplitKind::MatMulOutputDim));
    }

    #[test]
    fn detect_dim_split_matmul_no_trans() {
        let node = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["input".to_string(), "weight".to_string()],
            output: vec!["output".to_string()],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        shapes.insert("input".to_string(), vec![4, 145, 384]);
        shapes.insert("weight".to_string(), vec![384, 1536]);
        shapes.insert("output".to_string(), vec![4, 145, 1536]);
        let mut init_names = HashSet::new();
        init_names.insert("weight".to_string());
        let detection = detect_dim_split(&[node], &shapes, &init_names, 17);
        assert!(detection.is_some());
        let d = detection.unwrap();
        // Same expectations as the transB=1 Gemm case: 580 == 4 * 145 rows.
        assert_eq!(d.split_dim, 0);
        assert_eq!(d.dim_size, 580);
        assert_eq!(d.num_groups, 580);
        assert_eq!(d.elements_per_group, 1);
        assert_eq!(d.k_dim, 384);
        assert_eq!(d.n_dim, 1536);
        assert!(matches!(d.split_kind, DimSplitKind::MatMulOutputDim));
    }

    #[test]
    fn detect_dim_split_k_chunks_saturate_budget() {
        // k_dim=10, n_dim=300_000: row_cost=6M. Naive k_chunks=ceil(6M/2M)=3
        // yields chunk_size=ceil(10/3)=4 -> per-chunk=4*300_000*2=2.4M > 2M
        // (MAX_ESTIMATED_CONSTRAINTS). Loop bumps k_chunks to 4 giving
        // chunk_size=3 -> per-chunk=1.8M which fits.
        let node = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["input".to_string(), "weight".to_string()],
            output: vec!["output".to_string()],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        shapes.insert("input".to_string(), vec![4, 10]);
        shapes.insert("weight".to_string(), vec![10, 300_000]);
        shapes.insert("output".to_string(), vec![4, 300_000]);
        let mut init_names = HashSet::new();
        init_names.insert("weight".to_string());
        let d = detect_dim_split(&[node], &shapes, &init_names, 17).unwrap();
        assert_eq!(d.k_dim, 10);
        assert_eq!(d.n_dim, 300_000);
        let chunk_size = d.k_dim.div_ceil(d.k_chunks);
        assert!(
            chunk_size * d.n_dim * 2 <= MAX_ESTIMATED_CONSTRAINTS as usize,
            "per-chunk cost {} exceeds MAX {}",
            chunk_size * d.n_dim * 2,
            MAX_ESTIMATED_CONSTRAINTS
        );
    }

    #[test]
    fn detect_dim_split_single_row_with_k_chunking() {
        // total_rows=1 but k*n*2 > MAX: still detect, K-chunk it.
        let node = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["input".to_string(), "weight".to_string()],
            output: vec!["output".to_string()],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        shapes.insert("input".to_string(), vec![1, 2048]);
        shapes.insert("weight".to_string(), vec![2048, 2048]);
        shapes.insert("output".to_string(), vec![1, 2048]);
        let mut init_names = HashSet::new();
        init_names.insert("weight".to_string());
        let d = detect_dim_split(&[node], &shapes, &init_names, 17).unwrap();
        assert_eq!(d.dim_size, 1);
        assert_eq!(d.num_groups, 1);
        assert!(d.k_chunks > 1, "expected K-chunking for single row");
        let chunk_size = d.k_dim.div_ceil(d.k_chunks);
        assert!(chunk_size * d.n_dim * 2 <= MAX_ESTIMATED_CONSTRAINTS as usize);
    }

    #[test]
    fn detect_dim_split_skips_single_row_single_chunk() {
        // total_rows=1 and k*n*2 <= MAX: nothing to split via MatMul path.
        // The slice is still over budget (forced via a second MatMul), but
        // dim-split should decline and let the caller fall through.
        let node1 = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["input".to_string(), "w1".to_string()],
            output: vec!["mid".to_string()],
            ..Default::default()
        };
        let node2 = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["mid".to_string(), "w2".to_string()],
            output: vec!["output".to_string()],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        shapes.insert("input".to_string(), vec![1, 64]);
        shapes.insert("w1".to_string(), vec![64, 64]);
        shapes.insert("mid".to_string(), vec![1, 64]);
        shapes.insert("w2".to_string(), vec![64, 64]);
        shapes.insert("output".to_string(), vec![1, 64]);
        let mut init_names = HashSet::new();
        init_names.insert("w1".to_string());
        init_names.insert("w2".to_string());
        // Tiny per-op cost; slice estimate stays under MAX so detect_dim_split
        // returns None at the outer gate, which is what we want for a
        // single-row single-chunk MatMul.
        assert!(detect_dim_split(&[node1, node2], &shapes, &init_names, 17).is_none());
    }

    #[test]
    fn detect_dim_split_declines_infeasible_n() {
        // n_dim * 2 > MAX means even k_chunks == k_dim (chunk_size = 1)
        // cannot fit inside the per-chunk budget, so the MatMul branch must
        // decline. Use batch=1 so the BatchDim fallback path is not taken.
        let node = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["input".to_string(), "weight".to_string()],
            output: vec!["output".to_string()],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        // n_dim = 1_500_000 -> n*2 = 3_000_000 > MAX (2_000_000)
        shapes.insert("input".to_string(), vec![1, 4]);
        shapes.insert("weight".to_string(), vec![4, 1_500_000]);
        shapes.insert("output".to_string(), vec![1, 1_500_000]);
        let mut init_names = HashSet::new();
        init_names.insert("weight".to_string());
        let got = detect_dim_split(&[node], &shapes, &init_names, 17);
        assert!(
            got.as_ref()
                .is_none_or(|d| !matches!(d.split_kind, DimSplitKind::MatMulOutputDim)),
            "expected MatMul dim-split to decline, got {got:?}"
        );
    }

    #[test]
    fn detect_dim_split_skips_non_terminal_matmul() {
        // MatMul output is consumed by a later Add inside the same slice.
        // The dim-split runner only writes MatMul output to the cache, so
        // the Add would never run; detection must decline this MatMul and
        // either pick a later terminal MatMul or fall through.
        let matmul = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["input".to_string(), "weight".to_string()],
            output: vec!["mid".to_string()],
            ..Default::default()
        };
        let add = NodeProto {
            op_type: "Add".to_string(),
            input: vec!["mid".to_string(), "bias".to_string()],
            output: vec!["output".to_string()],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        shapes.insert("input".to_string(), vec![1, 145, 384]);
        shapes.insert("weight".to_string(), vec![384, 1536]);
        shapes.insert("bias".to_string(), vec![1536]);
        shapes.insert("mid".to_string(), vec![1, 145, 1536]);
        shapes.insert("output".to_string(), vec![1, 145, 1536]);
        let mut init_names = HashSet::new();
        init_names.insert("weight".to_string());
        init_names.insert("bias".to_string());
        let got = detect_dim_split(&[matmul, add], &shapes, &init_names, 17);
        assert!(
            got.as_ref()
                .is_none_or(|d| !matches!(d.split_kind, DimSplitKind::MatMulOutputDim)),
            "expected non-terminal MatMul to be declined, got {got:?}"
        );
    }

    #[test]
    fn detect_dim_split_picks_terminal_matmul_after_consumed_one() {
        // First MatMul feeds a second MatMul; only the second is terminal,
        // so detection must skip the first and select the second when both
        // are otherwise eligible.
        let m1 = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["input".to_string(), "w1".to_string()],
            output: vec!["mid".to_string()],
            ..Default::default()
        };
        let m2 = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["mid".to_string(), "w2".to_string()],
            output: vec!["output".to_string()],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        shapes.insert("input".to_string(), vec![4, 145, 384]);
        shapes.insert("w1".to_string(), vec![384, 1536]);
        shapes.insert("mid".to_string(), vec![4, 145, 1536]);
        shapes.insert("w2".to_string(), vec![1536, 384]);
        shapes.insert("output".to_string(), vec![4, 145, 384]);
        let mut init_names = HashSet::new();
        init_names.insert("w1".to_string());
        init_names.insert("w2".to_string());
        let d = detect_dim_split(&[m1, m2], &shapes, &init_names, 17).unwrap();
        assert_eq!(d.weight_name.as_deref(), Some("w2"));
        assert_eq!(d.output_name, "output");
        assert_eq!(d.k_dim, 1536);
        assert_eq!(d.n_dim, 384);
    }

    #[test]
    fn detect_dim_split_skips_gemm_trans_a() {
        use super::onnx_proto::make_attribute_int;
        let node = NodeProto {
            op_type: "Gemm".to_string(),
            input: vec!["input".to_string(), "weight".to_string()],
            output: vec!["output".to_string()],
            attribute: vec![make_attribute_int("transA", 1)],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        // Use batch=1 so the BatchDim fallback path does not mask the
        // MatMul-branch decline we want to assert.
        shapes.insert("input".to_string(), vec![1, 384, 145]);
        shapes.insert("weight".to_string(), vec![384, 1536]);
        shapes.insert("output".to_string(), vec![1, 145, 1536]);
        let mut init_names = HashSet::new();
        init_names.insert("weight".to_string());
        let got = detect_dim_split(&[node], &shapes, &init_names, 17);
        assert!(
            got.as_ref()
                .is_none_or(|d| !matches!(d.split_kind, DimSplitKind::MatMulOutputDim)),
            "expected Gemm transA=1 MatMul decline, got {got:?}"
        );
    }

    #[test]
    fn detect_dim_split_skips_gemm_with_bias() {
        use super::onnx_proto::make_attribute_int;
        let node = NodeProto {
            op_type: "Gemm".to_string(),
            input: vec![
                "input".to_string(),
                "weight".to_string(),
                "bias".to_string(),
            ],
            output: vec!["output".to_string()],
            attribute: vec![make_attribute_int("transB", 1)],
            ..Default::default()
        };
        let mut shapes = HashMap::new();
        // Use batch=1 so the BatchDim fallback path does not mask the
        // MatMul-branch decline we want to assert.
        shapes.insert("input".to_string(), vec![1, 145, 384]);
        shapes.insert("weight".to_string(), vec![1536, 384]);
        shapes.insert("bias".to_string(), vec![1536]);
        shapes.insert("output".to_string(), vec![1, 145, 1536]);
        let mut init_names = HashSet::new();
        init_names.insert("weight".to_string());
        init_names.insert("bias".to_string());
        // Detector should decline the MatMul branch since the template
        // builder cannot handle biased Gemm, forcing fall-through.
        let got = detect_dim_split(&[node], &shapes, &init_names, 17);
        assert!(
            got.as_ref()
                .is_none_or(|d| !matches!(d.split_kind, DimSplitKind::MatMulOutputDim)),
            "expected Gemm-with-bias MatMul decline, got {got:?}"
        );
    }

    #[test]
    fn create_matmul_dim_template_uses_info_weight_name() {
        // Graph has two MatMul nodes referencing different weights. The
        // template builder must pick the node whose input is info.weight_name,
        // not the first MatMul encountered.
        let x = onnx_proto::make_tensor_value_info("input", TensorProto::FLOAT, &[4, 64]);
        let y = onnx_proto::make_tensor_value_info("output", TensorProto::FLOAT, &[4, 2048]);
        let w_small = onnx_proto::make_tensor(
            "w_small",
            TensorProto::FLOAT,
            &[64, 64],
            vec![0.0f32; 64 * 64],
        );
        let w_big = onnx_proto::make_tensor(
            "w_big",
            TensorProto::FLOAT,
            &[64, 2048],
            vec![0.0f32; 64 * 2048],
        );
        let n1 = onnx_proto::make_node(
            "MatMul",
            vec!["input".into(), "w_small".into()],
            vec!["mid".into()],
            vec![],
        );
        let n2 = onnx_proto::make_node(
            "MatMul",
            vec!["mid".into(), "w_big".into()],
            vec!["output".into()],
            vec![],
        );
        let graph = onnx_proto::make_graph(
            "two_matmul",
            vec![n1, n2],
            vec![x],
            vec![y],
            vec![w_small, w_big],
        );
        let model = onnx_proto::make_model(graph, 13);
        let info = crate::schema::tiling::DimSplitInfo {
            slice_idx: 0,
            weight_name: Some("w_big".to_string()),
            input_name: "mid".to_string(),
            output_name: "output".to_string(),
            k_dim: 64,
            n_dim: 2048,
            k_chunks: 1,
            ..Default::default()
        };
        let tmp = tempfile::tempdir().unwrap();
        let tmpl_path = create_dim_split_template(&model, &info, tmp.path(), None).unwrap();
        let tmpl_model = onnx_proto::load_model(&tmpl_path).unwrap();
        let g = tmpl_model.graph.as_ref().unwrap();
        // The emitted template's weight initializer is looked up as "W".
        let w = g.initializer.iter().find(|i| i.name == "W").unwrap();
        // Template weight shape must reflect w_big (64, 2048), not w_small.
        assert_eq!(w.dims, vec![64, 2048]);
    }

    #[test]
    fn create_matmul_dim_template_disambiguates_shared_weight() {
        // Two MatMul ops share the same weight initializer (e.g. tied
        // weights). The template builder must select the op whose
        // input/output names match info, not the first node that happens
        // to reference the initializer.
        let x = onnx_proto::make_tensor_value_info("input", TensorProto::FLOAT, &[4, 64]);
        let y_a = onnx_proto::make_tensor_value_info("out_a", TensorProto::FLOAT, &[4, 32]);
        let y_b = onnx_proto::make_tensor_value_info("out_b", TensorProto::FLOAT, &[1, 32]);
        let shared_w = onnx_proto::make_tensor(
            "tied_w",
            TensorProto::FLOAT,
            &[64, 32],
            vec![0.0f32; 64 * 32],
        );
        // First op: input -> tied_w -> out_a (shape [4, 32])
        let n_a = onnx_proto::make_node(
            "MatMul",
            vec!["input".into(), "tied_w".into()],
            vec!["out_a".into()],
            vec![],
        );
        // Second op: alt_in -> tied_w -> out_b (shape [1, 32])
        let alt_in = onnx_proto::make_tensor_value_info("alt_in", TensorProto::FLOAT, &[1, 64]);
        let n_b = onnx_proto::make_node(
            "MatMul",
            vec!["alt_in".into(), "tied_w".into()],
            vec!["out_b".into()],
            vec![],
        );
        let graph = onnx_proto::make_graph(
            "shared_weight",
            vec![n_a, n_b],
            vec![x, alt_in],
            vec![y_a, y_b],
            vec![shared_w],
        );
        let model = onnx_proto::make_model(graph, 13);
        // Target the second op explicitly via input_name/output_name.
        let info = crate::schema::tiling::DimSplitInfo {
            slice_idx: 0,
            weight_name: Some("tied_w".to_string()),
            input_name: "alt_in".to_string(),
            output_name: "out_b".to_string(),
            k_dim: 64,
            n_dim: 32,
            k_chunks: 1,
            ..Default::default()
        };
        let tmp = tempfile::tempdir().unwrap();
        // Builder should succeed by binding the second op (the one whose
        // IO matches info), even though the first op also references the
        // same weight initializer.
        let tmpl_path = create_dim_split_template(&model, &info, tmp.path(), None).unwrap();
        let tmpl_model = onnx_proto::load_model(&tmpl_path).unwrap();
        let g = tmpl_model.graph.as_ref().unwrap();
        let w = g.initializer.iter().find(|i| i.name == "W").unwrap();
        assert_eq!(w.dims, vec![64, 32]);
    }

    /// Build a square MaxPool node (`input` -> `output`) with explicit
    /// kernel, stride and pads, plus an optional `ceil_mode` attribute.
    fn make_maxpool_node(
        kernel: i64,
        stride: i64,
        pads: [i64; 4],
        ceil_mode: Option,
    ) -> NodeProto {
        let mut attrs = vec![
            onnx_proto::make_attribute_ints("kernel_shape", &[kernel, kernel]),
            onnx_proto::make_attribute_ints("strides", &[stride, stride]),
            onnx_proto::make_attribute_ints("pads", &pads),
        ];
        if let Some(cm) = ceil_mode {
            attrs.push(onnx_proto::make_attribute_int("ceil_mode", cm));
        }
        onnx_proto::make_node(
            "MaxPool",
            vec!["input".into()],
            vec!["output".into()],
            attrs,
        )
    }

    #[test]
    fn pool_params_valid() {
        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], None);
        let pp = PoolParams::from_node(&node, 0);
        assert!(pp.is_some());
        let pp = pp.unwrap();
        assert_eq!(pp.kernel, [2, 2]);
        assert_eq!(pp.stride, [2, 2]);
    }

    #[test]
    fn pool_params_rejects_ceil_mode() {
        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], Some(1));
        assert!(PoolParams::from_node(&node, 0).is_none());
    }

    #[test]
    fn pool_params_accepts_ceil_mode_zero() {
        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], Some(0));
        assert!(PoolParams::from_node(&node, 0).is_some());
    }

    #[test]
    fn pool_params_rejects_auto_pad() {
        let mut attrs = vec![
            onnx_proto::make_attribute_ints("kernel_shape", &[2, 2]),
            onnx_proto::make_attribute_ints("strides", &[2, 2]),
        ];
        attrs.push(onnx_proto::AttributeProto {
            name: "auto_pad".into(),
            s: b"SAME_UPPER".to_vec(),
            ..Default::default()
        });
        let node = onnx_proto::make_node(
            "MaxPool",
            vec!["input".into()],
            vec!["output".into()],
            attrs,
        );
        assert!(PoolParams::from_node(&node, 0).is_none());
    }

    #[test]
    fn pool_params_rejects_non_maxpool() {
        let node = onnx_proto::make_node(
            "Conv",
            vec!["input".into()],
            vec!["output".into()],
            vec![onnx_proto::make_attribute_ints("kernel_shape", &[3, 3])],
        );
        assert!(PoolParams::from_node(&node, 0).is_none());
    }

    /// Build a single-node model applying `op` to `input`, with both the
    /// input and the output declared as FLOAT tensors of `shape`.
    fn make_elementwise_model(op: &str, shape: &[i64]) -> ModelProto {
        let x = onnx_proto::make_tensor_value_info("input", TensorProto::FLOAT, shape);
        let y = onnx_proto::make_tensor_value_info("output", TensorProto::FLOAT, shape);
        let node = onnx_proto::make_node(op, vec!["input".into()], vec!["output".into()], vec![]);
        let graph = onnx_proto::make_graph("test", vec![node], vec![x], vec![y], vec![]);
        onnx_proto::make_model(graph, 13)
    }

    #[test]
    fn fixed_segments_too_small_returns_none() {
        let model = make_elementwise_model("Relu", &[1, 3, 8, 8]);
        assert!(detect_elementwise_fixed_segments(model.graph.as_ref().unwrap()).is_none());
    }

    #[test]
    fn fixed_segments_detects_large_tensor() {
        let model = make_elementwise_model("Relu", &[1, 16, 64, 64]);
        let graph = model.graph.as_ref().unwrap();
        let det = detect_elementwise_fixed_segments(graph);
        assert!(det.is_some());
        if let Some(TilingDetection::FixedSegment {
            segment_size,
            total_elements,
            num_segments,
            ..
        }) = det
        {
            assert_eq!(total_elements, 16 * 64 * 64);
            assert_eq!(segment_size, ELEMENTWISE_SEGMENT_SIZE);
            // Integer ceil(total_elements / segment_size).
            assert_eq!(
                num_segments,
                (total_elements + segment_size - 1) / segment_size
            );
        } else {
            panic!("expected FixedSegment variant");
        }
    }

    #[test]
    fn fixed_segments_rejects_zero_dim() {
        let model = make_elementwise_model("Relu", &[1, 0, 64, 64]);
        assert!(detect_elementwise_fixed_segments(model.graph.as_ref().unwrap()).is_none());
    }

    #[test]
    fn fixed_segments_rejects_non_elementwise() {
        let x = onnx_proto::make_tensor_value_info("input", TensorProto::FLOAT, &[1, 16, 64, 64]);
        let y = onnx_proto::make_tensor_value_info("output", TensorProto::FLOAT, &[1, 16, 64, 64]);
        let node = onnx_proto::make_node(
            "Softmax",
            vec!["input".into()],
            vec!["output".into()],
            vec![],
        );
        let graph = onnx_proto::make_graph("test", vec![node], vec![x], vec![y], vec![]);
        let model = onnx_proto::make_model(graph, 13);
        assert!(detect_elementwise_fixed_segments(model.graph.as_ref().unwrap()).is_none());
    }

    #[test]
    fn create_pool_tile_slice_valid() {
        let x = onnx_proto::make_tensor_value_info("input", TensorProto::FLOAT, &[1, 3, 64, 64]);
        let y = onnx_proto::make_tensor_value_info("output", TensorProto::FLOAT, &[1, 3, 32, 32]);
        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], None);
        let graph = onnx_proto::make_graph("pool", vec![node], vec![x], vec![y], vec![]);
        let model = onnx_proto::make_model(graph, 13);
        let tmp = tempfile::tempdir().unwrap();
        let result = create_pool_tile_slice(&model, 16, 0, tmp.path());
        assert!(result.is_ok());
        let r = result.unwrap();
        assert!(r.path.contains("tile.onnx"));
    }

    #[test]
    fn create_pool_tile_slice_rejects_zero_tile() {
        let x = onnx_proto::make_tensor_value_info("input", TensorProto::FLOAT, &[1, 3, 64, 64]);
        let y = onnx_proto::make_tensor_value_info("output", TensorProto::FLOAT, &[1, 3, 32, 32]);
        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], None);
        let graph = onnx_proto::make_graph("pool", vec![node], vec![x], vec![y], vec![]);
        let model = onnx_proto::make_model(graph, 13);
        let tmp = tempfile::tempdir().unwrap();
        assert!(create_pool_tile_slice(&model, 0, 0, tmp.path()).is_err());
    }

    #[test]
    fn create_pool_tile_slice_no_pool_node() {
        let x = onnx_proto::make_tensor_value_info("input", TensorProto::FLOAT, &[1, 3, 64, 64]);
        let y = onnx_proto::make_tensor_value_info("output", TensorProto::FLOAT, &[1, 3, 64, 64]);
        let node = onnx_proto::make_node("Relu", vec!["input".into()], vec!["output".into()], vec![]);
        let graph = onnx_proto::make_graph("no_pool", vec![node], vec![x], vec![y], vec![]);
        let model = onnx_proto::make_model(graph, 13);
        let tmp = tempfile::tempdir().unwrap();
        assert!(create_pool_tile_slice(&model, 16, 0, tmp.path()).is_err());
    }

    #[test]
    fn estimate_slice_constraints_clamps_symbolic_dimensions() {
        // ONNX serializes dynamic axes as -1 and placeholder axes as 0.
        // Both must be clamped to 1 before forwarding to the jstprove
        // estimator, otherwise product(shape) multiplies by zero and
        // collapses the op's cost contribution to 0.
        let node = NodeProto {
            op_type: "MatMul".to_string(),
            input: vec!["input".to_string(), "weight".to_string()],
            output: vec!["output".to_string()],
            ..Default::default()
        };
        let mut symbolic_shapes = HashMap::new();
        symbolic_shapes.insert("input".to_string(), vec![-1, 64]);
        symbolic_shapes.insert("weight".to_string(), vec![64, 128]);
        symbolic_shapes.insert("output".to_string(), vec![0, 128]);
        let mut concrete_shapes = HashMap::new();
        concrete_shapes.insert("input".to_string(), vec![1, 64]);
        concrete_shapes.insert("weight".to_string(), vec![64, 128]);
        concrete_shapes.insert("output".to_string(), vec![1, 128]);
        let nodes = [node];
        let symbolic_cost = estimate_slice_constraints(&nodes, &symbolic_shapes);
        let concrete_cost = estimate_slice_constraints(&nodes, &concrete_shapes);
        assert!(
            symbolic_cost > 0,
            "symbolic dims must not collapse cost to zero"
        );
        assert_eq!(
            symbolic_cost, concrete_cost,
            "batch -1 and batch 0 must clamp to 1 and match concrete batch 1"
        );
    }
}
================================================ FILE: crates/dsperse/src/slicer/combiner.rs ================================================
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};

use super::onnx_proto::{self, ModelProto, TensorProto, ValueInfoProto};
use crate::error::{DsperseError, Result};
use crate::schema::metadata::ModelMetadata;

pub fn materialize_combined_model(
    model: &ModelProto,
    metadata: &ModelMetadata,
    traced_shapes: &HashMap>,
    traced_types: Option<&HashMap>,
) -> Result {
    let mut combined = model.clone();
    let graph = combined
        .graph
        .as_mut()
        .ok_or_else(|| DsperseError::Slicer("model.graph is None".into()))?;
    let existing_outputs: HashSet = graph.output.iter().map(|o| o.name.clone()).collect();
    let all_node_outputs: HashSet = graph
        .node
        .iter()
        .flat_map(|n| n.output.iter().cloned())
        .collect();
    let mut new_outputs: Vec = Vec::new();
    let mut added: HashSet = HashSet::new();
    {
        let vi_map = onnx_proto::build_value_info_map(graph);
        for slice in &metadata.slices
{ for output_name in &slice.dependencies.output { if existing_outputs.contains(output_name) || added.contains(output_name) { continue; } if !all_node_outputs.contains(output_name) { tracing::warn!( tensor = %output_name, slice = slice.index, "slice output not produced by any node in original graph, skipping" ); continue; } if let Some(vi) = resolve_value_info(output_name, &vi_map, traced_shapes, traced_types)? { new_outputs.push(vi); added.insert(output_name.clone()); } } for input_name in &slice.dependencies.filtered_inputs { if existing_outputs.contains(input_name) || added.contains(input_name) { continue; } if !all_node_outputs.contains(input_name) { tracing::debug!( tensor = %input_name, slice = slice.index, "slice filtered_input not produced by any node in original graph, skipping" ); continue; } if let Some(vi) = resolve_value_info(input_name, &vi_map, traced_shapes, traced_types)? { new_outputs.push(vi); added.insert(input_name.clone()); } } } } graph.output.extend(new_outputs); tracing::info!( intermediate_outputs = added.len(), total_outputs = graph.output.len(), "combined model with slice boundary outputs" ); Ok(combined) } const ONNX_STRING_DATATYPE: i32 = 8; const NON_NUMERIC_TENSOR_TYPES: &[i32] = &[ONNX_STRING_DATATYPE]; fn resolve_value_info( name: &str, vi_map: &HashMap, traced_shapes: &HashMap>, traced_types: Option<&HashMap>, ) -> Result> { if let Some(vi) = vi_map.get(name) { let elem_type = onnx_proto::elem_type_from_value_info(vi).unwrap_or(TensorProto::FLOAT); if NON_NUMERIC_TENSOR_TYPES.contains(&elem_type) { return Ok(None); } return Ok(Some((*vi).clone())); } let shape = traced_shapes.get(name).ok_or_else(|| { DsperseError::Slicer(format!( "no shape info for combined model output tensor '{name}'" )) })?; let elem_type = traced_types .and_then(|t| t.get(name).copied()) .unwrap_or(TensorProto::FLOAT); if NON_NUMERIC_TENSOR_TYPES.contains(&elem_type) { return Ok(None); } Ok(Some(onnx_proto::make_tensor_value_info( name, elem_type, shape, ))) } 
pub fn ensure_combined_materialized( slices_dir: &Path, metadata: &ModelMetadata, ) -> Result { let output_path = slices_dir.join("combined.onnx"); if output_path.exists() { return Ok(output_path); } materialize_combined_to_disk(slices_dir, metadata) } pub fn materialize_combined_to_disk( slices_dir: &Path, metadata: &ModelMetadata, ) -> Result { let traced_shapes = metadata.traced_shapes.as_ref().ok_or_else(|| { DsperseError::Slicer("metadata missing traced_shapes for combined model".into()) })?; let traced_types = metadata.traced_types.as_ref(); let original_path = metadata.original_model_path.as_ref().ok_or_else(|| { DsperseError::Slicer("metadata missing original_model_path for combined model".into()) })?; let model_path = if Path::new(original_path).is_absolute() { std::path::PathBuf::from(original_path) } else { slices_dir.join(original_path) }; let mut model = onnx_proto::load_model(&model_path)?; onnx_proto::normalize_opset(&mut model); let combined = materialize_combined_model(&model, metadata, traced_shapes, traced_types)?; let dest = slices_dir.join("combined.onnx"); onnx_proto::save_model(&combined, &dest)?; tracing::info!(path = %dest.display(), "materialized combined ONNX"); Ok(dest) } #[cfg(test)] mod tests { use super::*; use crate::schema::metadata::{ Dependencies, ModelMetadata, SliceMetadata, SliceShapeWrapper, TensorShape, }; fn make_test_model( node_output_types: HashMap, traced_shapes: HashMap>, ) -> (ModelProto, ModelMetadata) { let graph = onnx_proto::GraphProto { node: vec![onnx_proto::NodeProto { op_type: "Identity".to_string(), input: vec!["input".to_string()], output: vec![ "float_tensor".to_string(), "bool_tensor".to_string(), "string_tensor".to_string(), "int_tensor".to_string(), ], ..Default::default() }], input: vec![onnx_proto::make_tensor_value_info( "input", TensorProto::FLOAT, &[1, 3, 8, 8], )], output: vec![onnx_proto::make_tensor_value_info( "model_output", TensorProto::FLOAT, &[1, 3, 8, 8], )], ..Default::default() }; let 
model = onnx_proto::make_model(graph, 13); let metadata = ModelMetadata { slices: vec![SliceMetadata { index: 0, filename: "s0.onnx".to_string(), path: "s0.onnx".to_string(), relative_path: "s0.onnx".to_string(), shape: SliceShapeWrapper { tensor_shape: TensorShape { input: vec![], output: vec![], }, }, dependencies: Dependencies { input: vec![], filtered_inputs: vec![], output: vec![ "float_tensor".to_string(), "bool_tensor".to_string(), "string_tensor".to_string(), "int_tensor".to_string(), ], }, ..Default::default() }], traced_shapes: Some(traced_shapes.clone()), traced_types: Some(node_output_types), ..Default::default() }; (model, metadata) } #[test] fn bool_outputs_included_in_combined_model() { let mut node_output_types = HashMap::new(); node_output_types.insert("float_tensor".to_string(), TensorProto::FLOAT); node_output_types.insert("bool_tensor".to_string(), TensorProto::BOOL); node_output_types.insert("string_tensor".to_string(), ONNX_STRING_DATATYPE); node_output_types.insert("int_tensor".to_string(), TensorProto::INT64); let mut traced_shapes = HashMap::new(); traced_shapes.insert("float_tensor".to_string(), vec![1, 3, 8, 8]); traced_shapes.insert("bool_tensor".to_string(), vec![1, 3, 8, 8]); traced_shapes.insert("string_tensor".to_string(), vec![1, 3, 8, 8]); traced_shapes.insert("int_tensor".to_string(), vec![1, 3, 8, 8]); let (model, metadata) = make_test_model(node_output_types, traced_shapes.clone()); let traced_types = metadata.traced_types.as_ref(); let combined = materialize_combined_model(&model, &metadata, &traced_shapes, traced_types).unwrap(); let graph = combined.graph.as_ref().unwrap(); let float_vi = graph.output.iter().find(|o| o.name == "float_tensor"); assert!(float_vi.is_some()); let bool_vi = graph.output.iter().find(|o| o.name == "bool_tensor"); assert!(bool_vi.is_some()); let string_vi = graph.output.iter().find(|o| o.name == "string_tensor"); assert!( string_vi.is_none(), "string tensors should be excluded from combined outputs" 
); let int_vi = graph.output.iter().find(|o| o.name == "int_tensor"); assert!(int_vi.is_some()); } #[test] fn combined_model_has_intermediate_outputs() { let mut traced_shapes = HashMap::new(); traced_shapes.insert("float_tensor".to_string(), vec![1, 3, 8, 8]); traced_shapes.insert("bool_tensor".to_string(), vec![1]); traced_shapes.insert("string_tensor".to_string(), vec![1]); traced_shapes.insert("int_tensor".to_string(), vec![2, 4]); let mut types = HashMap::new(); types.insert("float_tensor".to_string(), TensorProto::FLOAT); types.insert("bool_tensor".to_string(), TensorProto::BOOL); types.insert("int_tensor".to_string(), TensorProto::INT64); let (model, metadata) = make_test_model(types, traced_shapes.clone()); let traced_types = metadata.traced_types.as_ref(); let combined = materialize_combined_model(&model, &metadata, &traced_shapes, traced_types).unwrap(); let graph = combined.graph.as_ref().unwrap(); assert!( graph.output.len() > 1, "combined model should have intermediate outputs" ); } #[test] fn combined_model_to_disk_roundtrip() { let dir = tempfile::tempdir().unwrap(); let slices_dir = dir.path(); let mut traced_shapes = HashMap::new(); traced_shapes.insert("float_tensor".to_string(), vec![1, 3, 8, 8]); traced_shapes.insert("bool_tensor".to_string(), vec![1]); traced_shapes.insert("string_tensor".to_string(), vec![1]); traced_shapes.insert("int_tensor".to_string(), vec![2, 4]); let mut types = HashMap::new(); types.insert("float_tensor".to_string(), TensorProto::FLOAT); types.insert("bool_tensor".to_string(), TensorProto::BOOL); types.insert("int_tensor".to_string(), TensorProto::INT64); let (model, mut metadata) = make_test_model(types, traced_shapes); metadata.original_model_path = Some("model.onnx".to_string()); let model_path = slices_dir.join("model.onnx"); onnx_proto::save_model(&model, &model_path).unwrap(); let meta_path = slices_dir.join("metadata.msgpack"); metadata.save(&meta_path).unwrap(); let dest = 
materialize_combined_to_disk(slices_dir, &metadata).unwrap();
        assert!(dest.exists());
        let loaded = onnx_proto::load_model(&dest).unwrap();
        let graph = loaded.graph.as_ref().unwrap();
        assert!(
            graph.output.len() > 1,
            "reloaded combined model should have intermediate outputs"
        );
    }

    /// `ensure_combined_materialized` must build `combined.onnx` on first use
    /// and, once the file exists, hand back the same path without rebuilding.
    #[test]
    fn ensure_combined_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let slices_dir = dir.path();
        let mut traced_shapes = HashMap::new();
        traced_shapes.insert("float_tensor".to_string(), vec![1, 3, 8, 8]);
        traced_shapes.insert("bool_tensor".to_string(), vec![1]);
        traced_shapes.insert("string_tensor".to_string(), vec![1]);
        traced_shapes.insert("int_tensor".to_string(), vec![2, 4]);
        let mut types = HashMap::new();
        types.insert("float_tensor".to_string(), TensorProto::FLOAT);
        types.insert("bool_tensor".to_string(), TensorProto::BOOL);
        types.insert("int_tensor".to_string(), TensorProto::INT64);
        let (model, mut metadata) = make_test_model(types, traced_shapes);
        metadata.original_model_path = Some("model.onnx".to_string());
        let model_path = slices_dir.join("model.onnx");
        onnx_proto::save_model(&model, &model_path).unwrap();
        let meta_path = slices_dir.join("metadata.msgpack");
        metadata.save(&meta_path).unwrap();

        // First call has nothing cached and must materialize the file.
        let dest1 = ensure_combined_materialized(slices_dir, &metadata).unwrap();
        assert!(dest1.exists());
        assert_eq!(dest1, slices_dir.join("combined.onnx"));

        // Overwrite with a sentinel: the second call must reuse the existing
        // file instead of re-materializing (which would clobber the sentinel).
        std::fs::write(&dest1, b"sentinel").unwrap();
        let dest2 = ensure_combined_materialized(slices_dir, &metadata).unwrap();
        assert_eq!(dest1, dest2);
        assert_eq!(std::fs::read(&dest2).unwrap(), b"sentinel");
    }
}
================================================ FILE: crates/dsperse/src/slicer/layernorm_fuse.rs ================================================
use std::collections::{HashMap, HashSet};

use super::onnx_proto::{
    AttributeProto, ModelProto, NodeProto, TensorProto, tensor_to_f32, tensor_to_i64,
};

/// Fuse inline LayerNorm subgraphs into a single `LayerNormalization` node.
///
/// The matched chain is ReduceMean -> Sub -> (Pow 2 | Mul self) ->
/// ReduceMean -> Add(eps) -> Sqrt -> Div, optionally followed by an affine
/// Mul(scale) / Add(bias) with initializer operands. Each match is replaced
/// by Transpose -> LayerNormalization -> Transpose. Shapes of the new
/// transposed intermediates are recorded in `traced_shapes`.
///
/// Returns the number of patterns fused (0 when the model has no graph).
pub fn fuse_inline_layernorms(
    model: &mut ModelProto,
    traced_shapes: &mut HashMap<String, Vec<i64>>,
) -> usize {
    let graph = match model.graph.as_mut() {
        Some(g) => g,
        None => return 0,
    };
    let initializers: HashMap<String, TensorProto> = graph
        .initializer
        .iter()
        .map(|t| (t.name.clone(), t.clone()))
        .collect();
    let producers: HashMap<String, usize> =
graph .node .iter() .enumerate() .flat_map(|(i, n)| { n.output .iter() .filter(|o| !o.is_empty()) .map(move |o| (o.clone(), i)) }) .collect(); let mut consumers: HashMap> = HashMap::new(); for (i, n) in graph.node.iter().enumerate() { for inp in &n.input { if !inp.is_empty() { consumers.entry(inp.clone()).or_default().push(i); } } } let mut drop: HashSet = HashSet::new(); let mut insertions: Vec<(usize, Vec, Vec)> = Vec::new(); let mut fused_id = 0usize; for (mean_idx, mean_node) in graph.node.iter().enumerate() { if drop.contains(&mean_idx) || mean_node.op_type != "ReduceMean" { continue; } let Some(m) = try_match_layernorm( mean_idx, mean_node, &graph.node, &producers, &consumers, &initializers, traced_shapes, &drop, ) else { continue; }; let (nodes, inits, shapes) = emit_replacement(&m, fused_id, &initializers); for (name, shape) in shapes { traced_shapes.insert(name, shape); } fused_id += 1; drop.extend(m.nodes_to_drop.iter().copied()); insertions.push((mean_idx, nodes, inits)); } let fused = insertions.len(); if fused == 0 { return 0; } for (_, _, inits) in &insertions { for t in inits { graph.initializer.push(t.clone()); } } let insertion_map: HashMap> = insertions .into_iter() .map(|(idx, nodes, _)| (idx, nodes)) .collect(); let mut new_nodes: Vec = Vec::with_capacity(graph.node.len()); for (i, n) in graph.node.drain(..).enumerate() { if let Some(inserts) = insertion_map.get(&i) { new_nodes.extend(inserts.iter().cloned()); continue; } if drop.contains(&i) { continue; } new_nodes.push(n); } graph.node = new_nodes; fused } struct MatchedPattern { x_name: String, axes: Vec, rank: usize, x_shape: Vec, eps: f32, scale_init: Option, bias_init: Option, output_name: String, nodes_to_drop: Vec, } #[allow(clippy::too_many_arguments, clippy::too_many_lines)] fn try_match_layernorm( mean_idx: usize, mean_node: &NodeProto, nodes: &[NodeProto], producers: &HashMap, consumers: &HashMap>, initializers: &HashMap, traced_shapes: &HashMap>, drop: &HashSet, ) -> Option { let 
raw_axes = reduce_axes(mean_node, initializers)?; if get_keepdims(mean_node).unwrap_or(1) != 1 { return None; } let x_name = mean_node.input.first()?.clone(); let mean_out = mean_node.output.first()?.clone(); let sub_idx = find_unique_consumer(consumers, &mean_out, "Sub", nodes, drop)?; let sub_node = &nodes[sub_idx]; if sub_node.input.len() < 2 || sub_node.input.first()? != &x_name || sub_node.input.get(1)? != &mean_out { return None; } let centered = sub_node.output.first()?.clone(); let sq_idx = find_square_consumer(consumers, ¢ered, nodes, initializers, drop)?; let sq_node = &nodes[sq_idx]; let sq_out = sq_node.output.first()?.clone(); let mean2_idx = find_unique_consumer(consumers, &sq_out, "ReduceMean", nodes, drop)?; let mean2_node = &nodes[mean2_idx]; let raw_axes2 = reduce_axes(mean2_node, initializers)?; if raw_axes2 != raw_axes { return None; } if get_keepdims(mean2_node).unwrap_or(1) != 1 { return None; } let var_out = mean2_node.output.first()?.clone(); let add_idx = find_unique_consumer(consumers, &var_out, "Add", nodes, drop)?; let add_node = &nodes[add_idx]; let eps = extract_binary_const_scalar(add_node, &var_out, initializers)?; let var_eps = add_node.output.first()?.clone(); let sqrt_idx = find_unique_consumer(consumers, &var_eps, "Sqrt", nodes, drop)?; let sqrt_node = &nodes[sqrt_idx]; let std_out = sqrt_node.output.first()?.clone(); let div_idx = find_unique_consumer(consumers, &std_out, "Div", nodes, drop)?; let div_node = &nodes[div_idx]; if div_node.input.len() < 2 || div_node.input.first()? != ¢ered || div_node.input.get(1)? 
!= &std_out { return None; } let norm_out = div_node.output.first()?.clone(); let mut nodes_to_drop = vec![ mean_idx, sub_idx, sq_idx, mean2_idx, add_idx, sqrt_idx, div_idx, ]; let mut output_name = norm_out.clone(); let mut scale_init: Option = None; let mut bias_init: Option = None; if let Some(mul_idx) = find_unique_consumer(consumers, &norm_out, "Mul", nodes, drop) { let mul_node = &nodes[mul_idx]; if let Some(scale) = other_input_if_init(mul_node, &norm_out, initializers) { scale_init = Some(scale); output_name = mul_node.output.first()?.clone(); nodes_to_drop.push(mul_idx); if let Some(add2_idx) = find_unique_consumer(consumers, &output_name, "Add", nodes, drop) { let add2_node = &nodes[add2_idx]; if let Some(bias) = other_input_if_init(add2_node, &output_name, initializers) { bias_init = Some(bias); output_name = add2_node.output.first()?.clone(); nodes_to_drop.push(add2_idx); } } } } // Soundness check: every intermediate tensor we are about to drop // (mean_out, centered, sq_out, var_out, var_eps, std_out, plus the // pre-affine norm_out when scale/bias are present) must have all // its live consumers inside nodes_to_drop. Otherwise some // downstream node still reads the intermediate and fusing would // disconnect it. 
let drop_set: HashSet = nodes_to_drop.iter().copied().collect(); let mut intermediates: Vec<&str> = vec![ mean_out.as_str(), centered.as_str(), sq_out.as_str(), var_out.as_str(), var_eps.as_str(), std_out.as_str(), ]; if scale_init.is_some() { intermediates.push(norm_out.as_str()); } for tname in intermediates { if let Some(list) = consumers.get(tname) { for &idx in list { if drop.contains(&idx) || drop_set.contains(&idx) { continue; } return None; } } } let x_shape = resolve_shape(&x_name, traced_shapes, initializers, nodes, producers)?; let rank = x_shape.len(); if rank == 0 { return None; } let axes: Vec = raw_axes.iter().map(|&a| normalize_axis(a, rank)).collect(); for &a in &axes { if a >= rank { return None; } // Reject dynamic / unresolved dims along the reduction axes: the // fused LayerNormalization circuit needs a concrete lane_size // and consumers of m.x_shape[a] later cast the dim to usize, // which silently wraps negative sentinels into huge values. if x_shape[a] <= 0 { return None; } } Some(MatchedPattern { x_name, axes, rank, x_shape, eps, scale_init, bias_init, output_name, nodes_to_drop, }) } fn resolve_shape( name: &str, traced_shapes: &HashMap>, initializers: &HashMap, _nodes: &[NodeProto], _producers: &HashMap, ) -> Option> { if let Some(s) = traced_shapes.get(name) && !s.is_empty() { return Some(s.clone()); } if let Some(t) = initializers.get(name) { return Some(t.dims.clone()); } None } fn reduce_axes(node: &NodeProto, initializers: &HashMap) -> Option> { if let Some(attr) = node.attribute.iter().find(|a| a.name == "axes") && !attr.ints.is_empty() { return Some(attr.ints.clone()); } if let Some(name) = node.input.get(1) && let Some(t) = initializers.get(name) { let v = tensor_to_i64(t); if !v.is_empty() { return Some(v); } } None } fn get_keepdims(node: &NodeProto) -> Option { node.attribute .iter() .find(|a| a.name == "keepdims") .map(|a| a.i) } fn find_unique_consumer( consumers: &HashMap>, tensor: &str, op_type: &str, nodes: &[NodeProto], 
drop: &HashSet, ) -> Option { let list = consumers.get(tensor)?; let live: Vec = list.iter().copied().filter(|i| !drop.contains(i)).collect(); if live.len() != 1 { return None; } let idx = live[0]; (nodes[idx].op_type == op_type).then_some(idx) } fn find_square_consumer( consumers: &HashMap>, tensor: &str, nodes: &[NodeProto], initializers: &HashMap, drop: &HashSet, ) -> Option { // The centered tensor in the inline-LN pattern has TWO legitimate // consumers: Pow / Mul (for the variance branch) AND Div (for the // normalization branch). Both belong to the fusion -- don't reject // them as orphan consumers. Final orphan-leak check happens after // the whole pattern matches in try_match_layernorm. let list = consumers.get(tensor)?; for &idx in list.iter().filter(|i| !drop.contains(i)) { let n = &nodes[idx]; match n.op_type.as_str() { "Pow" => { if n.input.len() >= 2 && n.input.first().map(String::as_str) == Some(tensor) && pow_exponent_is_two(n.input.get(1)?, initializers) { return Some(idx); } } "Mul" => { if n.input.len() == 2 && n.input.iter().all(|i| i == tensor) { return Some(idx); } } _ => {} } } None } fn pow_exponent_is_two(name: &str, initializers: &HashMap) -> bool { let Some(t) = initializers.get(name) else { return false; }; let f = tensor_to_f32(t); if let Some(&v) = f.first() && (v - 2.0).abs() < f32::EPSILON { return true; } let i = tensor_to_i64(t); matches!(i.first(), Some(&2)) } fn extract_binary_const_scalar( node: &NodeProto, non_const_input: &str, initializers: &HashMap, ) -> Option { if node.input.len() != 2 { return None; } let (a, b) = (node.input.first()?, node.input.get(1)?); let other_name = if a.as_str() == non_const_input { b } else if b.as_str() == non_const_input { a } else { return None; }; let t = initializers.get(other_name)?; tensor_to_f32(t).first().copied() } fn other_input_if_init( node: &NodeProto, non_const_input: &str, initializers: &HashMap, ) -> Option { if node.input.len() != 2 { return None; } let a = 
node.input.first()?.clone(); let b = node.input.get(1)?.clone(); let other = if a == non_const_input { b } else if b == non_const_input { a } else { return None; }; initializers.get(&other).map(|_| other) } type ReplacementShapes = Vec<(String, Vec)>; type Replacement = (Vec, Vec, ReplacementShapes); fn emit_replacement( m: &MatchedPattern, fused_id: usize, initializers: &HashMap, ) -> Replacement { let rank = m.rank; let axes_set: HashSet = m.axes.iter().copied().collect(); let mut forward_perm: Vec = (0..rank) .filter(|d| !axes_set.contains(d)) .map(|d| d as i64) .collect(); for &a in &m.axes { forward_perm.push(a as i64); } let mut inverse_perm: Vec = vec![0; rank]; for (new_pos, &old_pos) in forward_perm.iter().enumerate() { inverse_perm[old_pos as usize] = new_pos as i64; } let lane_size: usize = m.axes.iter().map(|&a| m.x_shape[a] as usize).product(); let prefix = format!("/__dsperse/fused_ln_{fused_id}"); let xt_name = format!("{prefix}/xt"); let yt_name = format!("{prefix}/yt"); let (scale_name, scale_init_opt) = materialize_1d_initializer( &format!("{prefix}/scale"), m.scale_init.as_deref(), initializers, lane_size, 1.0, ); let (bias_name, bias_init_opt) = materialize_1d_initializer( &format!("{prefix}/bias"), m.bias_init.as_deref(), initializers, lane_size, 0.0, ); let mut nodes = Vec::new(); nodes.push(NodeProto { name: format!("{prefix}/Transpose_in"), op_type: "Transpose".to_string(), input: vec![m.x_name.clone()], output: vec![xt_name.clone()], attribute: vec![int_list_attr("perm", &forward_perm)], ..Default::default() }); let ln_axis = (rank - m.axes.len()) as i64; nodes.push(NodeProto { name: format!("{prefix}/LayerNormalization"), op_type: "LayerNormalization".to_string(), input: vec![xt_name, scale_name, bias_name], output: vec![yt_name.clone()], attribute: vec![int_attr("axis", ln_axis), float_attr("epsilon", m.eps)], ..Default::default() }); nodes.push(NodeProto { name: format!("{prefix}/Transpose_out"), op_type: "Transpose".to_string(), input: 
vec![yt_name], output: vec![m.output_name.clone()], attribute: vec![int_list_attr("perm", &inverse_perm)], ..Default::default() }); let mut inits = Vec::new(); inits.extend(scale_init_opt); inits.extend(bias_init_opt); let xt_shape: Vec = forward_perm .iter() .map(|&p| m.x_shape[p as usize]) .collect(); let yt_shape = xt_shape.clone(); let shapes = vec![ (format!("{prefix}/xt"), xt_shape), (format!("{prefix}/yt"), yt_shape), ]; (nodes, inits, shapes) } fn materialize_1d_initializer( new_name: &str, source: Option<&str>, initializers: &HashMap, lane_size: usize, default_fill: f32, ) -> (String, Option) { let Some(src) = source else { return ( new_name.to_string(), Some(const_vector(new_name, lane_size, default_fill)), ); }; let Some(t) = initializers.get(src) else { return ( new_name.to_string(), Some(const_vector(new_name, lane_size, default_fill)), ); }; let elems = tensor_to_f32(t); if elems.len() == lane_size && t.dims.len() == 1 { return (src.to_string(), None); } let vals: Vec = if elems.len() == lane_size { elems } else if elems.len() == 1 { vec![elems[0]; lane_size] } else { return ( new_name.to_string(), Some(const_vector(new_name, lane_size, default_fill)), ); }; (new_name.to_string(), Some(make_f32_vector(new_name, &vals))) } fn const_vector(name: &str, len: usize, fill: f32) -> TensorProto { make_f32_vector(name, &vec![fill; len]) } fn make_f32_vector(name: &str, vals: &[f32]) -> TensorProto { TensorProto { name: name.to_string(), data_type: TensorProto::FLOAT, dims: vec![vals.len() as i64], float_data: vals.to_vec(), ..Default::default() } } fn normalize_axis(axis: i64, rank: usize) -> usize { if axis < 0 { (rank as i64 + axis) as usize } else { axis as usize } } fn int_attr(name: &str, v: i64) -> AttributeProto { AttributeProto { name: name.to_string(), r#type: 2, i: v, ..Default::default() } } fn float_attr(name: &str, v: f32) -> AttributeProto { AttributeProto { name: name.to_string(), r#type: 1, f: v, ..Default::default() } } fn int_list_attr(name: 
&str, vals: &[i64]) -> AttributeProto { AttributeProto { name: name.to_string(), r#type: 7, ints: vals.to_vec(), ..Default::default() } } ================================================ FILE: crates/dsperse/src/slicer/materializer.rs ================================================ use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; use super::autotiler::{self, ChannelSplitParams}; use super::onnx_proto::{self, GraphProto, ModelProto, NodeProto, TensorProto, ValueInfoProto}; use super::onnx_slicer::broadcast_shapes; use crate::error::{DsperseError, Result}; use crate::schema::metadata::ModelMetadata; const MAX_BACKWARD_DEPTH: usize = 64; fn resolve_shape_backward( tensor_name: &str, graph: &GraphProto, traced_shapes: &HashMap>, ) -> Option> { resolve_shape_backward_inner(tensor_name, graph, traced_shapes, 0) } fn resolve_shape_backward_inner( tensor_name: &str, graph: &GraphProto, traced_shapes: &HashMap>, depth: usize, ) -> Option> { if depth > MAX_BACKWARD_DEPTH { return None; } if let Some(s) = traced_shapes.get(tensor_name) { return Some(s.clone()); } if let Some(vi) = graph.value_info.iter().find(|v| v.name == tensor_name) && let Some(shape) = onnx_proto::shape_from_value_info(vi) { return Some(shape); } for init in &graph.initializer { if init.name == tensor_name { return Some(init.dims.to_vec()); } } let producer = graph .node .iter() .find(|n| n.output.contains(&tensor_name.to_string()))?; let op = producer.op_type.as_str(); if super::is_shape_preserving(op) { let inp = producer.input.first()?; return resolve_shape_backward_inner(inp, graph, traced_shapes, depth + 1); } if op == "Shape" { let inp = producer.input.first()?; let in_shape = resolve_shape_backward_inner(inp, graph, traced_shapes, depth + 1)?; return Some(vec![in_shape.len() as i64]); } if super::is_binary_arithmetic(op) { let resolved: Vec> = producer .input .iter() .filter_map(|inp| resolve_shape_backward_inner(inp, graph, traced_shapes, depth + 1)) .collect(); let refs: 
Vec<&Vec> = resolved.iter().collect(); if let Some(broadcasted) = broadcast_shapes(&refs) { return Some(broadcasted); } } None } pub fn materialize_slice_model( model: &ModelProto, slice_points: &[usize], traced_shapes: &HashMap>, traced_types: &HashMap, slice_idx: usize, ) -> Result { let graph = model .graph .as_ref() .ok_or_else(|| DsperseError::Slicer("model.graph is None".into()))?; let total_nodes = graph.node.len(); let segment_ranges = super::build_segment_ranges(slice_points, Some(total_nodes)); let &(start, end) = segment_ranges.get(slice_idx).ok_or_else(|| { DsperseError::Slicer(format!( "slice index {slice_idx} out of range (have {} segments)", segment_ranges.len() )) })?; let init_map: HashMap<&str, &TensorProto> = graph .initializer .iter() .map(|i| (i.name.as_str(), i)) .collect(); let vi_map = onnx_proto::build_value_info_map(graph); let init_types: HashMap<&str, i32> = graph .initializer .iter() .map(|i| (i.name.as_str(), i.data_type)) .collect(); let node_output_types = build_node_output_types(graph); let future_deps = compute_future_dependencies(graph, &segment_ranges, &init_map); let constant_producers: HashMap = graph .node .iter() .filter(|n| n.op_type == "Constant") .flat_map(|n| { n.output.iter().filter_map(|out| { n.attribute .iter() .find(|a| a.name == "value") .and_then(|a| a.t.as_ref()) .map(|t| (out.clone(), t)) }) }) .collect(); let nodes: Vec = graph.node[start..end].to_vec(); let seg_outputs: HashSet = nodes .iter() .flat_map(|n| n.output.iter().cloned()) .collect(); let seg_inputs_set: HashSet = nodes .iter() .flat_map(|n| { let mut inputs: Vec = n.input.iter().filter(|s| !s.is_empty()).cloned().collect(); if super::is_control_flow(&n.op_type) { let outer_refs = super::collect_subgraph_outer_refs(n, graph); inputs.extend(outer_refs); } inputs }) .collect(); let future = future_deps.get(&slice_idx).cloned().unwrap_or_default(); let query = SegmentQuery { nodes: &nodes, seg_outputs: &seg_outputs, seg_inputs_set: &seg_inputs_set, 
future_inputs: &future, }; let ctx = ShapeContext { graph, init_map: &init_map, vi_map: &vi_map, traced_shapes, traced_types, init_types: &init_types, node_output_types: &node_output_types, constant_producers: &constant_producers, }; let (inputs, outputs, initializers) = get_segment_details(&query, &ctx)?; let opset_version = model .opset_import .iter() .find(|o| o.domain.is_empty() || o.domain == "ai.onnx") .map(|o| o.version) .unwrap_or(13); let seg_graph = onnx_proto::make_graph( &format!("segment_{slice_idx}_graph"), nodes, inputs, outputs, initializers, ); Ok(onnx_proto::make_model(seg_graph, opset_version)) }

/// Build slice `slice_idx` of `model` and save it to `output_path`.
///
/// Parent directories of `output_path` are created as needed. Returns
/// `output_path` as an owned path on success.
pub fn materialize_slice_to_disk(
    model: &ModelProto,
    slice_points: &[usize],
    traced_shapes: &HashMap<String, Vec<i64>>,
    traced_types: &HashMap<String, i32>,
    slice_idx: usize,
    output_path: &Path,
) -> Result<PathBuf> {
    let slice_model =
        materialize_slice_model(model, slice_points, traced_shapes, traced_types, slice_idx)?;
    if let Some(parent) = output_path.parent() {
        std::fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?;
    }
    onnx_proto::save_model(&slice_model, output_path)?;
    Ok(output_path.to_path_buf())
}

/// Return the path of `slice_{idx}/payload/slice_{idx}.onnx`, producing it if needed.
///
/// Resolution order:
/// 1. the ONNX file already exists on disk;
/// 2. the slice directory does not exist but `slice_{idx}.dslice` does: the
///    archive is extracted and used if it contains the ONNX file;
/// 3. otherwise the slice is re-cut from `metadata.original_model_path`
///    (relative paths are resolved against `slices_dir`) using the traced
///    shapes/types stored in `metadata`.
///
/// Every successful path also materializes the slice's tiling, channel-split
/// and dim-split artifacts.
///
/// # Errors
/// `DsperseError::Slicer` if `metadata` has no entry for `slice_idx`, or if a
/// re-cut is needed but `traced_shapes` / `original_model_path` is missing.
/// I/O, archive and model-loading errors are propagated.
pub fn ensure_slice_materialized(
    slices_dir: &Path,
    metadata: &ModelMetadata,
    slice_idx: usize,
) -> Result<PathBuf> {
    let slice_meta = metadata
        .slices
        .iter()
        .find(|s| s.index == slice_idx)
        .ok_or_else(|| DsperseError::Slicer(format!("no slice metadata for index {slice_idx}")))?;
    let slice_dir = slices_dir.join(format!("slice_{slice_idx}"));
    let payload_dir = slice_dir.join("payload");
    let onnx_path = payload_dir.join(format!("slice_{slice_idx}.onnx"));

    if onnx_path.exists() {
        materialize_tiling_artifacts(slices_dir, metadata, slice_meta, slice_idx)?;
        return Ok(onnx_path);
    }

    if !slice_dir.exists() {
        let archive = slices_dir.join(format!("slice_{slice_idx}.dslice"));
        if archive.exists() {
            extract_dslice_archive(&archive, &slice_dir)?;
            if onnx_path.exists() {
                materialize_tiling_artifacts(slices_dir, metadata, slice_meta, slice_idx)?;
                return Ok(onnx_path);
            }
        }
    }

    let traced_shapes = metadata.traced_shapes.as_ref().ok_or_else(|| {
        DsperseError::Slicer("metadata missing traced_shapes for materialization".into())
    })?;
    let empty_types: HashMap<String, i32> = HashMap::new();
    let traced_types = metadata.traced_types.as_ref().unwrap_or(&empty_types);
    let original_path = metadata.original_model_path.as_ref().ok_or_else(|| {
        DsperseError::Slicer("metadata missing original_model_path for materialization".into())
    })?;
    let model_path = if Path::new(original_path).is_absolute() {
        PathBuf::from(original_path)
    } else {
        slices_dir.join(original_path)
    };
    let mut model = onnx_proto::load_model(&model_path)?;
    onnx_proto::normalize_opset(&mut model);
    let model_with_shapes = apply_traced_shapes(model, traced_shapes);
    std::fs::create_dir_all(&payload_dir).map_err(|e| DsperseError::io(e, &payload_dir))?;
    materialize_slice_to_disk(
        &model_with_shapes,
        &metadata.slice_points,
        traced_shapes,
        traced_types,
        slice_idx,
        &onnx_path,
    )?;
    tracing::info!(slice = slice_idx, path = %onnx_path.display(), "materialized slice");
    materialize_tiling_artifacts(slices_dir, metadata, slice_meta, slice_idx)?;
    Ok(onnx_path)
}

/// Create any tiling / channel-split / dim-split artifacts that the slice
/// metadata asks for and that are missing on disk.
///
/// Each kind reads the slice ONNX at `slice_{idx}/payload/slice_{idx}.onnx`
/// and writes into that payload directory:
/// * `tiling.tile`: built when the tile file is missing. The builder is
///   elementwise if every node is elementwise, pool if any node is MaxPool,
///   and generic otherwise.
/// * `channel_split`: rebuilt when no groups are recorded or any group file is
///   missing (and `num_groups > 0`).
/// * `dim_split`: `dim_template.onnx` is built when missing. A failure is only
///   logged, and the slice is then compiled as a single unit.
fn materialize_tiling_artifacts(
    slices_dir: &Path,
    metadata: &ModelMetadata,
    slice_meta: &crate::schema::metadata::SliceMetadata,
    slice_idx: usize,
) -> Result<()> {
    let payload_dir = slices_dir
        .join(format!("slice_{slice_idx}"))
        .join("payload");
    let onnx_path = payload_dir.join(format!("slice_{slice_idx}.onnx"));

    if let Some(ref tiling) = slice_meta.tiling
        && let Some(ref tile) = tiling.tile
    {
        let tile_path = slices_dir.join(&tile.path);
        if !tile_path.exists() {
            let slice_model = onnx_proto::load_model(&onnx_path)?;
            let is_ew = slice_model.graph.as_ref().is_some_and(|g| {
                !g.node.is_empty() && g.node.iter().all(|n| super::is_elementwise(&n.op_type))
            });
            let is_pool = slice_model
                .graph
                .as_ref()
                .is_some_and(|g| g.node.iter().any(|n| n.op_type == "MaxPool"));
            if is_ew {
                let seg_size = tiling.segment_size.ok_or_else(|| {
                    crate::error::DsperseError::Slicer(format!(
                        "slice {slice_idx}: elementwise tiling metadata missing segment_size; re-slice the model"
                    ))
                })? as i64;
                autotiler::create_elementwise_tile_slice(
                    &slice_model,
                    seg_size,
                    slice_idx,
                    &payload_dir,
                )?;
            } else if is_pool {
                autotiler::create_pool_tile_slice(
                    &slice_model,
                    tiling.tile_size as i64,
                    slice_idx,
                    &payload_dir,
                )?;
            } else {
                autotiler::create_tile_slice(
                    &slice_model,
                    tiling.tile_size as i64,
                    slice_idx,
                    &payload_dir,
                )?;
            }
            tracing::info!(slice = slice_idx, "materialized tile ONNX");
        }
    }

    if let Some(ref cs) = slice_meta.channel_split {
        let needs_materialization = cs.groups.is_empty()
            || cs
                .groups
                .iter()
                .any(|g| !slices_dir.join(&g.path).exists());
        if needs_materialization && cs.num_groups > 0 {
            let slice_model = onnx_proto::load_model(&onnx_path)?;
            let params = ChannelSplitParams {
                c_in: cs.c_in as i64,
                c_out: cs.c_out as i64,
                num_groups: cs.num_groups as i64,
                channels_per_group: cs.channels_per_group as i64,
                h: cs.h as i64,
                w: cs.w as i64,
                slice_idx,
            };
            let cs_info = autotiler::apply_channel_splitting(
                &slice_model,
                &params,
                &cs.input_name,
                &cs.output_name,
                &payload_dir,
            )?;
            tracing::info!(
                slice = slice_idx,
                groups = cs_info.groups.len(),
                "materialized channel groups"
            );
        }
    }

    if let Some(ref ds) = slice_meta.dim_split
        && ds.num_groups > 0
    {
        let tmpl_path = payload_dir.join("dim_template.onnx");
        if !tmpl_path.exists() {
            let slice_model = onnx_proto::load_model(&onnx_path)?;
            match autotiler::create_dim_split_template(
                &slice_model,
                ds,
                &payload_dir,
                metadata.traced_shapes.as_ref(),
            ) {
                Ok(_) => {
                    tracing::info!(slice = slice_idx, "materialized dim-split template");
                }
                Err(e) => {
                    // Best effort: without a template the slice is compiled whole.
                    tracing::info!(
                        slice = slice_idx,
                        error = %e,
                        "dim-split template skipped, will compile as single slice"
                    );
                }
            }
        }
    }
    Ok(())
}

/// Run [`ensure_slice_materialized`] for every slice in `metadata`, in
/// parallel with rayon. Returns an error if any slice fails.
pub fn ensure_all_slices_materialized(slices_dir: &Path, metadata: &ModelMetadata) -> Result<()> {
    use rayon::prelude::*;
    metadata.slices.par_iter().try_for_each(|slice| {
        ensure_slice_materialized(slices_dir, metadata, slice.index).map(|_| ())
    })
}

/// Stamp the traced static shapes onto `model`'s graph.
///
/// Graph inputs, outputs and value_info entries named in `shapes` get their
/// tensor shape replaced; entries without a tensor type are left untouched.
/// Traced tensors that are not described anywhere get a new value_info entry.
/// The element type of a new entry comes from the initializers, then from
/// statically inferred node output types, and falls back to FLOAT.
/// New entries are appended in sorted name order, so the output does not
/// depend on `HashMap` iteration order.
fn apply_traced_shapes(mut model: ModelProto, shapes: &HashMap<String, Vec<i64>>) -> ModelProto {
    fn set_shape(vi: &mut ValueInfoProto, shape: &[i64]) {
        if let Some(ref mut tp) = vi.r#type
            && let Some(onnx_proto::onnx::type_proto::Value::TensorType(ref mut tt)) = tp.value
        {
            tt.shape = Some(onnx_proto::onnx::TensorShapeProto {
                dim: shape
                    .iter()
                    .map(|&d| onnx_proto::onnx::tensor_shape_proto::Dimension {
                        denotation: String::new(),
                        value: Some(
                            onnx_proto::onnx::tensor_shape_proto::dimension::Value::DimValue(d),
                        ),
                    })
                    .collect(),
            });
        }
    }

    if let Some(ref mut graph) = model.graph {
        for inp in &mut graph.input {
            if let Some(shape) = shapes.get(&inp.name) {
                set_shape(inp, shape);
            }
        }
        for out in &mut graph.output {
            if let Some(shape) = shapes.get(&out.name) {
                set_shape(out, shape);
            }
        }
        for vi in &mut graph.value_info {
            if let Some(shape) = shapes.get(&vi.name) {
                set_shape(vi, shape);
            }
        }

        // Owned names: graph.value_info is pushed to below.
        let existing: HashSet<String> = graph
            .input
            .iter()
            .chain(graph.output.iter())
            .chain(graph.value_info.iter())
            .map(|vi| vi.name.clone())
            .collect();
        let init_types: HashMap<&str, i32> = graph
            .initializer
            .iter()
            .map(|i| (i.name.as_str(), i.data_type))
            .collect();
        let node_output_types = build_node_output_types(graph);

        let mut missing: Vec<(&String, &Vec<i64>)> = shapes
            .iter()
            .filter(|(name, _)| !existing.contains(*name))
            .collect();
        missing.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (name, shape) in missing {
            let from_init = init_types.get(name.as_str()).copied();
            let from_node = node_output_types.get(name).copied();
            let elem_type = from_init.or(from_node).unwrap_or_else(|| {
                tracing::debug!(tensor = %name, "no explicit dtype in initializers or node outputs, assuming FLOAT");
                TensorProto::FLOAT
            });
            graph
                .value_info
                .push(onnx_proto::make_tensor_value_info(name, elem_type, shape));
        }
    }
    model
}

/// Map each segment index to the tensor names read by any later segment.
///
/// A segment's inputs are the names read by its nodes that it does not
/// produce itself and that are not initializers. For control-flow nodes this
/// includes the outer-scope tensors their subgraphs read. The last segment
/// maps to an empty set.
fn compute_future_dependencies(
    graph: &GraphProto,
    segment_ranges: &[(usize, usize)],
    init_map: &HashMap<&str, &TensorProto>,
) -> HashMap<usize, HashSet<String>> {
    let seg_inputs: Vec<HashSet<String>> = segment_ranges
        .iter()
        .map(|&(start, end)| {
            let nodes = &graph.node[start..end];
            let produced: HashSet<&str> = nodes
                .iter()
                .flat_map(|n| n.output.iter().map(String::as_str))
                .collect();
            nodes
                .iter()
                .flat_map(|n| {
                    if super::is_control_flow(&n.op_type) {
                        let mut refs = super::collect_subgraph_outer_refs(n, graph);
                        refs.extend(n.input.iter().cloned());
                        refs
                    } else {
                        n.input.to_vec()
                    }
                })
                .filter(|inp| {
                    !inp.is_empty()
                        && !produced.contains(inp.as_str())
                        && !init_map.contains_key(inp.as_str())
                })
                .collect()
        })
        .collect();

    // Walk back to front and carry the union of every later segment's
    // inputs, instead of re-unioning all later segments for each index.
    let mut future: HashMap<usize, HashSet<String>> =
        HashMap::with_capacity(segment_ranges.len());
    let mut later_inputs: HashSet<String> = HashSet::new();
    for (seg_idx, inputs) in seg_inputs.iter().enumerate().rev() {
        future.insert(seg_idx, later_inputs.clone());
        later_inputs.extend(inputs.iter().cloned());
    }
    future
}

/// The node range of one segment and its boundary name sets.
struct SegmentQuery<'a> {
    // Nodes belonging to the segment.
    nodes: &'a [NodeProto],
    // Every tensor name produced by `nodes`.
    seg_outputs: &'a HashSet<String>,
    // Candidate input names of the segment.
    seg_inputs_set: &'a HashSet<String>,
    // Names read by later segments (see `compute_future_dependencies`).
    future_inputs: &'a HashSet<String>,
}

/// Parent-model lookups used to type and shape segment boundary tensors.
struct ShapeContext<'a> {
    graph: &'a GraphProto,
    init_map: &'a HashMap<&'a str, &'a TensorProto>,
    vi_map: &'a HashMap<String, &'a ValueInfoProto>,
    traced_shapes: &'a HashMap<String, Vec<i64>>,
    traced_types: &'a HashMap<String, i32>,
    init_types: &'a HashMap<&'a str, i32>,
    node_output_types: &'a HashMap<String, i32>,
    // Tensors produced by Constant nodes; these become segment initializers.
    constant_producers: &'a HashMap<String, TensorProto>,
}

impl ShapeContext<'_> {
    fn resolve_elem_type(&self, name: &str) -> i32 {
        // Resolution order: parent value_info dtype is implicitly used
        // upstream via vi_map; fall back to initializer dtype, then to
        // dtype-aware tract trace, then to the small set of ops whose
        // output dtype is fixed by the spec, and finally to FLOAT.
        // Skipping the traced_types lookup is the bug that turned
        // INT64 indices (TopK, Tile-of-int, Slice-of-int) into FLOAT
        // value_info entries; the witness path then quantised them by
        // alpha and produced indices like 138_149_888 = 1054 * 2^17.
        self.init_types
            .get(name)
            .copied()
            .or_else(|| self.traced_types.get(name).copied())
            .or_else(|| self.node_output_types.get(name).copied())
            .unwrap_or(TensorProto::FLOAT)
    }
}

/// Compute the graph inputs, graph outputs and initializers of one segment.
///
/// Inputs: every name in `seg_inputs_set` that the segment does not produce,
/// in sorted order. Parent initializers and Constant-produced tensors become
/// segment initializers. Every other input becomes a graph input, taken from
/// the parent value_info when present, and otherwise built from the traced
/// (or back-resolved) shape and [`ShapeContext::resolve_elem_type`].
///
/// Outputs: every tensor the segment produces (sorted), except
/// Constant-produced ones, that is either not read inside the segment or is
/// needed by a later segment or as a model output.
///
/// # Errors
/// `DsperseError::Slicer` when a boundary tensor has neither a parent
/// value_info nor a resolvable shape.
fn get_segment_details(
    query: &SegmentQuery<'_>,
    ctx: &ShapeContext<'_>,
) -> Result<(Vec<ValueInfoProto>, Vec<ValueInfoProto>, Vec<TensorProto>)> {
    let mut inputs = Vec::new();
    let mut outputs = Vec::new();
    let mut initializers = Vec::new();

    let model_output_names: HashSet<&str> =
        ctx.graph.output.iter().map(|o| o.name.as_str()).collect();
    // Every name read by a node of this segment. It is built once rather than
    // rescanning all nodes for each candidate output.
    let consumed_in_segment: HashSet<&str> = query
        .nodes
        .iter()
        .flat_map(|n| n.input.iter().map(String::as_str))
        .collect();

    // Shape of a boundary tensor that has no parent value_info: the traced
    // shape if present, otherwise one recovered by walking back through producers.
    let boundary_shape = |name: &String, role: &str| -> Result<Vec<i64>> {
        ctx.traced_shapes
            .get(name)
            .cloned()
            .or_else(|| resolve_shape_backward(name, ctx.graph, ctx.traced_shapes))
            .ok_or_else(|| {
                DsperseError::Slicer(format!(
                    "no traced shape for segment {role} tensor '{name}'"
                ))
            })
    };

    // Names come from a set, so each appears at most once.
    let mut sorted_inputs: Vec<&String> = query.seg_inputs_set.iter().collect();
    sorted_inputs.sort();
    for inp_name in sorted_inputs {
        if query.seg_outputs.contains(inp_name) {
            continue;
        }
        if let Some(init) = ctx.init_map.get(inp_name.as_str()) {
            initializers.push(TensorProto::clone(init));
        } else if let Some(producer) = ctx.constant_producers.get(inp_name) {
            let mut tensor = TensorProto::clone(producer);
            tensor.name = inp_name.clone();
            initializers.push(tensor);
        } else if let Some(vi) = ctx.vi_map.get(inp_name) {
            inputs.push(ValueInfoProto::clone(vi));
        } else {
            let shape = boundary_shape(inp_name, "input")?;
            inputs.push(onnx_proto::make_tensor_value_info(
                inp_name,
                ctx.resolve_elem_type(inp_name),
                &shape,
            ));
        }
    }

    let mut sorted_outputs: Vec<&String> = query.seg_outputs.iter().collect();
    sorted_outputs.sort();
    for out_name in sorted_outputs {
        if ctx.constant_producers.contains_key(out_name) {
            continue;
        }
        let consumed_internally = consumed_in_segment.contains(out_name.as_str());
        let needed_externally = query.future_inputs.contains(out_name)
            || model_output_names.contains(out_name.as_str());
        if consumed_internally && !needed_externally {
            continue;
        }
        if let Some(vi) = ctx.vi_map.get(out_name) {
            outputs.push(ValueInfoProto::clone(vi));
        } else {
            let shape = boundary_shape(out_name, "output")?;
            outputs.push(onnx_proto::make_tensor_value_info(
                out_name,
                ctx.resolve_elem_type(out_name),
                &shape,
            ));
        }
    }
    Ok((inputs, outputs, initializers))
}

/// Statically infer the ONNX element type of as many tensors in `graph` as possible.
///
/// Seeds come from initializers, then from input / value_info / output
/// entries with a non-zero elem type (initializers win). Nodes are visited in
/// `graph.node` order, which is assumed to be topological, so a node whose
/// output type follows an input can look that input up. Tensors whose type
/// cannot be derived here are absent from the result.
pub fn build_node_output_types(graph: &GraphProto) -> HashMap<String, i32> {
    // Propagate ONNX output dtypes in topological order so that
    // every node's output dtype is derivable from already-resolved
    // input dtypes. This is the source of truth for the slicer
    // when tract's runtime trace falls back to f32 because it can't
    // statically evaluate a node (e.g. TopK with a runtime K, or
    // any node downstream of one that taints during tract's
    // best-effort eval). Without this, INT64 indices flowing
    // through Tile / Slice / Reshape / GatherElements get tagged
    // as FLOAT in slice value_info and the witness path quantises
    // them by alpha, producing nonsense indices like
    // 1054 * 2^17 = 138_149_888.
    let mut types: HashMap<String, i32> = HashMap::new();
    for init in &graph.initializer {
        if init.data_type != 0 {
            types.insert(init.name.clone(), init.data_type);
        }
    }
    for vi in graph
        .input
        .iter()
        .chain(graph.value_info.iter())
        .chain(graph.output.iter())
    {
        if let Some(dt) = onnx_proto::elem_type_from_value_info(vi)
            && dt != 0
            && !types.contains_key(&vi.name)
        {
            types.insert(vi.name.clone(), dt);
        }
    }

    // Ops whose every output has the element type of input 0. OneHot
    // (typed by `values`) and Unique (INT64 side outputs) are handled
    // explicitly below.
    let pass_through_first_input: &[&str] = &[
        "Tile",
        "Slice",
        "Reshape",
        "Transpose",
        "Squeeze",
        "Unsqueeze",
        "Identity",
        "Flatten",
        "Expand",
        "Concat",
        "Gather",
        "GatherElements",
        "GatherND",
        "Pad",
        "Compress",
        "ScatterND",
        "ScatterElements",
        "Scatter",
        "Split",
        "DepthToSpace",
        "SpaceToDepth",
        "ReverseSequence",
        "Resize",
        "Upsample",
        "Crop",
    ];
    let always_int64: &[&str] = &[
        "Shape",
        "NonZero",
        "ArgMax",
        "ArgMin",
        // NonMaxSuppression's selected_indices output is INT64.
        "NonMaxSuppression",
    ];
    let always_bool: &[&str] = &[
        "Equal",
        "Less",
        "LessOrEqual",
        "Greater",
        "GreaterOrEqual",
        "And",
        "Or",
        "Not",
        "Xor",
        "IsNaN",
        "IsInf",
    ];

    for node in &graph.node {
        match node.op_type.as_str() {
            "Cast" => {
                if let Some(to) = onnx_proto::get_attribute_int(node, "to") {
                    for out in &node.output {
                        if !out.is_empty() {
                            types.insert(out.clone(), to as i32);
                        }
                    }
                }
            }
            "Constant" => {
                if let Some(t) = node
                    .attribute
                    .iter()
                    .find(|a| a.name == "value")
                    .and_then(|a| a.t.as_ref())
                    && let Some(out) = node.output.first()
                    && !out.is_empty()
                {
                    types.insert(out.clone(), t.data_type);
                }
            }
            "ConstantOfShape" => {
                let dt = node
                    .attribute
                    .iter()
                    .find(|a| a.name == "value")
                    .and_then(|a| a.t.as_ref())
                    .map(|t| t.data_type)
                    .unwrap_or(TensorProto::FLOAT);
                for out in &node.output {
                    if !out.is_empty() {
                        types.insert(out.clone(), dt);
                    }
                }
            }
            "MaxPool" => {
                if let Some(idx_out) = node.output.get(1)
                    && !idx_out.is_empty()
                {
                    types.insert(idx_out.clone(), TensorProto::INT64);
                }
                if let Some(val_out) = node.output.first()
                    && !val_out.is_empty()
                    && let Some(in_name) = node.input.first()
                    && let Some(&dt) = types.get(in_name.as_str())
                {
                    types.insert(val_out.clone(), dt);
                }
            }
            "TopK" => {
                if let Some(val_out) = node.output.first()
                    && !val_out.is_empty()
                    && let Some(in_name) = node.input.first()
                    && let Some(&dt) = types.get(in_name.as_str())
                {
                    types.insert(val_out.clone(), dt);
                }
                if let Some(idx_out) = node.output.get(1)
                    && !idx_out.is_empty()
                {
                    types.insert(idx_out.clone(), TensorProto::INT64);
                }
            }
            "OneHot" => {
                // ONNX: the output element type follows `values` (input 2),
                // not `indices` (input 0, usually INT64).
                if let Some(&dt) = node
                    .input
                    .get(2)
                    .filter(|s| !s.is_empty())
                    .and_then(|n| types.get(n.as_str()))
                {
                    for out in &node.output {
                        if !out.is_empty() {
                            types.insert(out.clone(), dt);
                        }
                    }
                }
            }
            "Unique" => {
                // ONNX: output 0 (unique values) keeps the input element type.
                // Outputs 1..=3 (indices, inverse_indices, counts) are INT64.
                if let Some(values_out) = node.output.first()
                    && !values_out.is_empty()
                    && let Some(in_name) = node.input.first()
                    && let Some(&dt) = types.get(in_name.as_str())
                {
                    types.insert(values_out.clone(), dt);
                }
                for out in node.output.iter().skip(1) {
                    if !out.is_empty() {
                        types.insert(out.clone(), TensorProto::INT64);
                    }
                }
            }
            "Where" => {
                if let Some(out) = node.output.first()
                    && !out.is_empty()
                {
                    if let Some(&dt) = node.input.get(1).and_then(|n| types.get(n.as_str())) {
                        types.insert(out.clone(), dt);
                    } else if let Some(&dt) =
                        node.input.get(2).and_then(|n| types.get(n.as_str()))
                    {
                        types.insert(out.clone(), dt);
                    }
                }
            }
            op if always_int64.contains(&op) => {
                for out in &node.output {
                    if !out.is_empty() {
                        types.insert(out.clone(), TensorProto::INT64);
                    }
                }
            }
            op if always_bool.contains(&op) => {
                for out in &node.output {
                    if !out.is_empty() {
                        types.insert(out.clone(), TensorProto::BOOL);
                    }
                }
            }
            op if pass_through_first_input.contains(&op) => {
                if let Some(in_name) = node.input.first().filter(|s| !s.is_empty())
                    && let Some(&dt) = types.get(in_name.as_str())
                {
                    for out in &node.output {
                        if !out.is_empty() {
                            types.insert(out.clone(), dt);
                        }
                    }
                }
            }
            _ => {}
        }
    }
    types
}

/// Extract the `.dslice` zip `archive` into `dest`.
///
/// The archive is unpacked into a sibling temp directory named after `dest`
/// and the process id, and then renamed into place. If the rename fails but
/// `dest` exists (another extractor won), this returns success.
///
/// # Errors
/// I/O errors when opening the archive or creating the temp directory;
/// `DsperseError::Slicer` when the zip cannot be read or extracted, or the
/// rename fails and `dest` does not exist.
fn extract_dslice_archive(archive: &Path, dest: &Path) -> Result<()> {
    // Open and validate the archive before creating the temp directory, so
    // a missing or corrupt archive does not leave a stray directory behind.
    let file = std::fs::File::open(archive).map_err(|e| DsperseError::io(e, archive))?;
    let mut zip = zip::ZipArchive::new(file).map_err(|e| {
        DsperseError::Slicer(format!("reading dslice archive {}: {e}", archive.display()))
    })?;

    let tmp_dir = dest.with_file_name(format!(
        ".{}.extracting.{}",
        dest.file_name().unwrap_or_default().to_string_lossy(),
        std::process::id()
    ));
    std::fs::create_dir_all(&tmp_dir).map_err(|e| DsperseError::io(e, &tmp_dir))?;

    if let Err(e) = zip.extract(&tmp_dir) {
        std::fs::remove_dir_all(&tmp_dir).ok();
        return Err(DsperseError::Slicer(format!(
            "extracting {} to {}: {e}",
            archive.display(),
            tmp_dir.display()
        )));
    }
    if let Err(e) = std::fs::rename(&tmp_dir, dest) {
        std::fs::remove_dir_all(&tmp_dir).ok();
        if dest.exists() {
            return Ok(());
        }
        return Err(DsperseError::Slicer(format!(
            "renaming {} to {}: {e}",
            tmp_dir.display(),
            dest.display()
        )));
    }
    tracing::debug!(archive = %archive.display(), dest = %dest.display(), "extracted dslice archive");
    Ok(())
}

/// Best-effort removal of `slices_dir/slice_id`. Logs a warning only when
/// removal fails and the directory is still present.
pub fn cleanup_extracted_slice(slices_dir: &Path, slice_id: &str) {
    let extract_dir = slices_dir.join(slice_id);
    if std::fs::remove_dir_all(&extract_dir).is_err() && extract_dir.exists() {
        tracing::warn!(dir = %extract_dir.display(), "failed to remove extracted slice dir");
    }
}
================================================ FILE: crates/dsperse/src/slicer/mod.rs ================================================
pub mod analyzer;
pub mod autotiler;
pub mod combiner;
pub(crate) mod layernorm_fuse;
pub mod materializer;
pub(crate) mod onnx_fold;
pub mod onnx_proto;
pub(crate)
mod onnx_shapes;
pub mod onnx_slicer;
pub(crate) mod self_div_rewrite;
pub(crate) mod trace;

pub use onnx_slicer::slice_model;

// Single-input activations applied element by element.
pub(crate) const UNARY_ACTIVATIONS: &[&str] = &[
    "Relu",
    "LeakyRelu",
    "PRelu",
    "Sigmoid",
    "Tanh",
    "Clip",
    "Neg",
    "Abs",
    "Sqrt",
    "Exp",
    "Log",
    "Sin",
    "Cos",
    "Erf",
];
pub(crate) const UNARY_STRUCTURAL: &[&str] = &["Cast", "Not", "Identity", "Dropout"];
pub(crate) const BINARY_ARITHMETIC: &[&str] = &["Add", "Sub", "Mul", "Div", "Pow", "Max", "Min"];
pub(crate) const NORMALIZATION_OPS: &[&str] =
    &["BatchNormalization", "Softmax", "LayerNormalization"];
pub(crate) const LAYOUT_OPS: &[&str] = &[
    "Reshape",
    "Transpose",
    "Flatten",
    "Squeeze",
    "Unsqueeze",
    "Gather",
];
// Ops that carry subgraph attributes.
pub(crate) const CONTROL_FLOW_OPS: &[&str] = &["Loop", "If", "Scan"];

/// True for `Loop`, `If` and `Scan`.
pub(crate) fn is_control_flow(op: &str) -> bool {
    CONTROL_FLOW_OPS.contains(&op)
}

/// Outer-scope tensor names read from inside `node`'s subgraph attributes
/// (`g` and `graphs`), recursing into nested subgraphs.
///
/// A name is reported when a subgraph node reads it, the subgraph level that
/// reads it does not define it (as input, initializer or node output), and
/// `graph` knows it as an input, initializer, node output or value_info.
/// The result is sorted and de-duplicated.
pub(crate) fn collect_subgraph_outer_refs(
    node: &onnx_proto::NodeProto,
    graph: &onnx_proto::GraphProto,
) -> Vec<String> {
    let mut outer_refs = Vec::new();
    for attr in &node.attribute {
        for sg in attr.g.iter().chain(attr.graphs.iter()) {
            collect_outer_refs_recursive(sg, graph, &mut outer_refs);
        }
    }
    outer_refs.sort();
    outer_refs.dedup();
    outer_refs
}

/// Append to `outer_refs` every name read inside `subgraph` (or any nested
/// subgraph) that is not local to that level and exists in `outer_graph`.
fn collect_outer_refs_recursive(
    subgraph: &onnx_proto::GraphProto,
    outer_graph: &onnx_proto::GraphProto,
    outer_refs: &mut Vec<String>,
) {
    // Build the outer-scope name set once and share it with every nested
    // level instead of rebuilding it per visited subgraph.
    let outer_names: std::collections::HashSet<&str> = outer_graph
        .input
        .iter()
        .map(|vi| vi.name.as_str())
        .chain(outer_graph.initializer.iter().map(|i| i.name.as_str()))
        .chain(
            outer_graph
                .node
                .iter()
                .flat_map(|n| n.output.iter().map(String::as_str)),
        )
        .chain(outer_graph.value_info.iter().map(|vi| vi.name.as_str()))
        .collect();
    collect_refs_against(subgraph, &outer_names, outer_refs);
}

/// Recursive worker for `collect_outer_refs_recursive`. `outer_names` stays
/// fixed while locality is judged separately at each subgraph level.
fn collect_refs_against(
    subgraph: &onnx_proto::GraphProto,
    outer_names: &std::collections::HashSet<&str>,
    outer_refs: &mut Vec<String>,
) {
    let local_names: std::collections::HashSet<&str> = subgraph
        .input
        .iter()
        .map(|vi| vi.name.as_str())
        .chain(subgraph.initializer.iter().map(|i| i.name.as_str()))
        .chain(
            subgraph
                .node
                .iter()
                .flat_map(|n| n.output.iter().map(String::as_str)),
        )
        .collect();
    for sg_node in &subgraph.node {
        for inp in &sg_node.input {
            if !inp.is_empty()
                && !local_names.contains(inp.as_str())
                && outer_names.contains(inp.as_str())
            {
                outer_refs.push(inp.clone());
            }
        }
        for attr in &sg_node.attribute {
            for nested_sg in attr.g.iter().chain(attr.graphs.iter()) {
                collect_refs_against(nested_sg, outer_names, outer_refs);
            }
        }
    }
}

/// True for ops whose output shape equals their input shape: unary
/// activations, unary structural ops and normalizations.
pub(crate) fn is_shape_preserving(op: &str) -> bool {
    UNARY_ACTIVATIONS.contains(&op)
        || UNARY_STRUCTURAL.contains(&op)
        || NORMALIZATION_OPS.contains(&op)
}

/// Ops the slicer may absorb into an adjacent activation slice
/// without creating a new compile boundary. This is a superset of
/// `is_shape_preserving`: it additionally covers the layout ops
/// (Reshape / Transpose / Flatten / Squeeze / Unsqueeze / Gather)
/// which CHANGE the tensor shape but do not introduce heavy
/// compute -- grouping them with the producer keeps transformer
/// reshape-transpose chains from shattering into N one-op slices.
///
/// Do NOT reuse this for shape-fallback decisions:
/// `is_shape_preserving` is consumed by the trace / materializer
/// to assume `output_shape == input_shape`, which is FALSE for
/// every op in LAYOUT_OPS by definition. Keep that predicate
/// strict and route slicer-grouping checks through this function
/// instead.
pub(crate) fn is_slice_passthrough(op: &str) -> bool { is_shape_preserving(op) || LAYOUT_OPS.contains(&op) } pub(crate) fn is_elementwise(op: &str) -> bool { UNARY_ACTIVATIONS.contains(&op) || BINARY_ARITHMETIC.contains(&op) } pub(crate) fn is_binary_arithmetic(op: &str) -> bool { BINARY_ARITHMETIC.contains(&op) } pub(crate) fn build_segment_ranges( slice_points: &[usize], total_nodes: Option, ) -> Vec<(usize, usize)> { let mut points = slice_points.to_vec(); if let Some(total) = total_nodes && !points.contains(&total) { points.push(total); } points.sort(); points.dedup(); let mut ranges = Vec::new(); for i in 0..points.len() { let start = if i > 0 { points[i - 1] } else { 0 }; let end = points[i]; if start < end { ranges.push((start, end)); } } ranges } ================================================ FILE: crates/dsperse/src/slicer/onnx_fold.rs ================================================ use std::collections::{HashMap, HashSet}; use super::onnx_proto::{ GraphProto, ModelProto, NodeProto, TensorProto, tensor_to_f32, tensor_to_i64, }; pub fn fold_constant_nodes(model: &mut ModelProto) -> HashSet { let graph = match model.graph.as_mut() { Some(g) => g, None => return HashSet::new(), }; let mut folded_tensors: Vec = Vec::new(); let mut folded_names: HashSet = HashSet::new(); for node in &graph.node { if node.op_type != "Constant" { continue; } let out_name = match node.output.first() { Some(n) if !n.is_empty() => n, _ => continue, }; let tensor = match node.attribute.iter().find(|a| a.name == "value") { Some(a) => match a.t.as_ref() { Some(t) => t, None => continue, }, None => continue, }; let mut t = tensor.clone(); t.name = out_name.clone(); folded_tensors.push(t); folded_names.insert(out_name.clone()); } if folded_names.is_empty() { return folded_names; } graph .node .retain(|n| n.op_type != "Constant" || !n.output.iter().any(|o| folded_names.contains(o))); let count = folded_tensors.len(); graph.initializer.extend(folded_tensors); tracing::info!(count, 
"folded Constant ops into initializers"); let propagated_names = propagate_constants(graph); if !propagated_names.is_empty() { tracing::info!( propagated = propagated_names.len(), "propagated constants after Constant-node folding" ); } folded_names.extend(propagated_names); // Graph simplification runs before Conv+BN fusion so that any // Identity chain sitting between a Conv and a BatchNormalization // collapses first, exposing a contiguous Conv -> BN pattern to // the fusion pass. let identity_count = remove_identity_nodes(graph); if identity_count > 0 { tracing::info!(identity_count, "removed Identity nodes"); } let dead_count = eliminate_dead_nodes(graph); if dead_count > 0 { tracing::info!(dead_count, "eliminated dead nodes"); } let fused = fuse_conv_batchnorm(graph); if fused > 0 { tracing::info!(fused, "fused Conv+BatchNormalization pairs"); } folded_names } pub fn remove_identity_nodes(graph: &mut GraphProto) -> usize { let identity_map: HashMap = graph .node .iter() .filter(|n| n.op_type == "Identity" && n.input.len() == 1 && n.output.len() == 1) .filter(|n| !n.input[0].is_empty() && !n.output[0].is_empty()) .map(|n| (n.output[0].clone(), n.input[0].clone())) .collect(); if identity_map.is_empty() { return 0; } fn resolve(name: &str, map: &HashMap) -> String { let mut current = name; let mut visited = HashSet::new(); while let Some(target) = map.get(current) { if !visited.insert(current) { break; } current = target; } current.to_string() } let output_names: HashSet = graph.output.iter().map(|o| o.name.clone()).collect(); // Only rewire consumers whose Identity output is NOT an exported // graph output. Exported names are the model's public interface // and must survive as-is; we preserve those Identity nodes // instead of renaming the graph output. Rewriting graph.output // in place would silently change the model's API and let DCE // below remove the Identity that produces the exported tensor. 
let drop_map: HashMap = identity_map .iter() .filter(|(out, _)| !output_names.contains(out.as_str())) .map(|(out, inp)| (out.clone(), inp.clone())) .collect(); for node in &mut graph.node { // Skip the node that produced this drop-map entry so we // don't rewrite its own input to its own output. Guard // the output-slot access: the drop_map construction only // accepts len-1 Identity nodes, but a malformed Identity // with zero outputs could still appear in graph.node and // must not trip an index panic here. let is_dropped_identity = node.op_type == "Identity" && node .output .first() .is_some_and(|o| drop_map.contains_key(o.as_str())); if is_dropped_identity { continue; } for inp in &mut node.input { if drop_map.contains_key(inp.as_str()) { *inp = resolve(inp, &drop_map); } } } let count = drop_map.len(); graph.node.retain(|n| { !(n.op_type == "Identity" && n.output.len() == 1 && drop_map.contains_key(&n.output[0])) }); count } pub fn eliminate_dead_nodes(graph: &mut GraphProto) -> usize { let output_names: HashSet = graph.output.iter().map(|o| o.name.clone()).collect(); let mut consumed: HashSet = output_names; let mut changed = true; while changed { changed = false; for node in &graph.node { let produces_consumed = node.output.iter().any(|o| consumed.contains(o)); if produces_consumed { for inp in &node.input { if !inp.is_empty() && consumed.insert(inp.clone()) { changed = true; } } } } } let before = graph.node.len(); graph .node .retain(|n| n.output.iter().any(|o| consumed.contains(o))); let removed = before - graph.node.len(); if removed > 0 { graph.initializer.retain(|i| consumed.contains(&i.name)); graph.value_info.retain(|vi| consumed.contains(&vi.name)); } removed } pub fn propagate_constants_with_shapes( graph: &mut GraphProto, traced_shapes: &HashMap>, ) -> usize { for node in &graph.node { if node.op_type == "Shape" && let Some(inp_name) = node.input.first() && let Some(full_shape) = traced_shapes.get(inp_name) && let Some(out_name) = 
node.output.first() && !out_name.is_empty() && !graph.initializer.iter().any(|i| i.name == *out_name) { let ndim = full_shape.len() as i64; let start_attr = node .attribute .iter() .find(|a| a.name == "start") .map(|a| a.i) .unwrap_or(0); let end_attr = node .attribute .iter() .find(|a| a.name == "end") .map(|a| a.i) .unwrap_or(ndim); let start = if start_attr < 0 { (ndim + start_attr).max(0) as usize } else { (start_attr as usize).min(full_shape.len()) }; let end = if end_attr < 0 { (ndim + end_attr).max(0) as usize } else { (end_attr as usize).min(full_shape.len()) }; let sliced: Vec = if start < end { full_shape[start..end].to_vec() } else { vec![] }; graph.initializer.push(TensorProto { name: out_name.clone(), data_type: TensorProto::INT64, dims: vec![sliced.len() as i64], int64_data: sliced, ..Default::default() }); } } let init_names: HashSet = graph.initializer.iter().map(|i| i.name.clone()).collect(); graph .node .retain(|n| n.op_type != "Shape" || !n.output.iter().any(|o| init_names.contains(o))); let folded = propagate_constants(graph); folded.len() } pub(crate) fn propagate_constants(graph: &mut GraphProto) -> HashSet { let mut constants: HashMap = graph .initializer .iter() .map(|t| (t.name.clone(), t.clone())) .collect(); let mut folded_node_indices: HashSet = HashSet::new(); loop { let mut progress = false; for (idx, node) in graph.node.iter().enumerate() { if folded_node_indices.contains(&idx) { continue; } let inputs: Vec<&str> = node .input .iter() .filter(|s| !s.is_empty()) .map(String::as_str) .collect(); if inputs.is_empty() { continue; } if !inputs.iter().all(|name| constants.contains_key(*name)) { continue; } let input_tensors: Vec<&TensorProto> = inputs.iter().map(|n| &constants[*n]).collect(); if let Some(outputs) = eval_const_node(node, &input_tensors) { for (out_name, tensor) in outputs { constants.insert(out_name, tensor); } folded_node_indices.insert(idx); progress = true; } } if !progress { break; } } if folded_node_indices.is_empty() { 
return HashSet::new(); } let mut new_init_names: HashSet = HashSet::new(); for idx in &folded_node_indices { for out in &graph.node[*idx].output { if !out.is_empty() && constants.contains_key(out) { new_init_names.insert(out.clone()); } } } let mut consumed_by_remaining: HashSet = graph .node .iter() .enumerate() .filter(|(i, _)| !folded_node_indices.contains(i)) .flat_map(|(_, n)| n.input.iter().cloned()) .collect(); for node in &graph.node { if super::is_control_flow(&node.op_type) { let outer_refs = super::collect_subgraph_outer_refs(node, graph); consumed_by_remaining.extend(outer_refs); } } let output_names: HashSet = graph.output.iter().map(|o| o.name.clone()).collect(); for name in &new_init_names { if (consumed_by_remaining.contains(name) || output_names.contains(name)) && let Some(t) = constants.get(name) && !graph.initializer.iter().any(|i| i.name == *name) { graph.initializer.push(t.clone()); } } let removed_outputs: HashSet = folded_node_indices .iter() .flat_map(|idx| graph.node[*idx].output.iter().cloned()) .collect(); graph .input .retain(|vi| !removed_outputs.contains(&vi.name) || output_names.contains(&vi.name)); let count = folded_node_indices.len(); let mut kept = Vec::with_capacity(graph.node.len() - count); for (idx, node) in graph.node.drain(..).enumerate() { if !folded_node_indices.contains(&idx) { kept.push(node); } } graph.node = kept; tracing::info!(count, "propagated constant subgraphs into initializers"); new_init_names } fn eval_const_node( node: &NodeProto, inputs: &[&TensorProto], ) -> Option> { let out_name = node.output.first()?.clone(); if out_name.is_empty() { return None; } match node.op_type.as_str() { "Identity" => { let mut t = inputs[0].clone(); t.name = out_name.clone(); Some(vec![(out_name, t)]) } "Cast" => eval_cast(node, inputs[0], &out_name), "Sqrt" => eval_unary_f32(inputs[0], &out_name, f32::sqrt), "Neg" => eval_unary_f32(inputs[0], &out_name, |x| -x), "Abs" => eval_unary_f32(inputs[0], &out_name, f32::abs), "Exp" => 
eval_unary_f32(inputs[0], &out_name, f32::exp), "Log" => eval_unary_f32(inputs[0], &out_name, f32::ln), "Ceil" => eval_unary_f32(inputs[0], &out_name, f32::ceil), "Floor" => eval_unary_f32(inputs[0], &out_name, f32::floor), "Reciprocal" => eval_unary_f32(inputs[0], &out_name, |x| 1.0 / x), "Relu" => eval_unary_f32(inputs[0], &out_name, |x| x.max(0.0)), "Sigmoid" => eval_unary_f32(inputs[0], &out_name, |x| 1.0 / (1.0 + (-x).exp())), "Tanh" => eval_unary_f32(inputs[0], &out_name, f32::tanh), "Add" => eval_binary_f32(inputs, &out_name, |a, b| a + b), "Sub" => eval_binary_f32(inputs, &out_name, |a, b| a - b), "Mul" => eval_binary_f32(inputs, &out_name, |a, b| a * b), "Div" => eval_binary_f32(inputs, &out_name, |a, b| a / b), "Pow" => eval_binary_f32(inputs, &out_name, f32::powf), "Reshape" => eval_reshape(node, inputs, &out_name), "Squeeze" => eval_squeeze(node, inputs, &out_name), "Unsqueeze" => eval_unsqueeze(node, inputs, &out_name), "Shape" => eval_shape(node, inputs[0], &out_name), "Gather" if inputs.len() >= 2 => eval_gather(node, inputs, &out_name), "Slice" if inputs.len() >= 3 => eval_slice(inputs, &out_name), "Concat" => eval_concat(node, inputs, &out_name), "ConstantOfShape" => eval_constant_of_shape(node, inputs[0], &out_name), "Where" if inputs.len() == 3 => eval_where(inputs, &out_name), "Range" if inputs.len() == 3 => eval_range(inputs, &out_name), "Equal" => eval_cmp(inputs, &out_name, |a, b| a == b, |a, b| a == b), "Less" => eval_cmp(inputs, &out_name, |a, b| a < b, |a, b| a < b), "Greater" => eval_cmp(inputs, &out_name, |a, b| a > b, |a, b| a > b), "Not" => eval_not(inputs[0], &out_name), "And" => eval_logical(inputs, &out_name, |a, b| a & b), "Or" => eval_logical(inputs, &out_name, |a, b| a | b), "Transpose" => eval_transpose(node, inputs[0], &out_name), "ReduceMean" => eval_reduce(node, inputs, &out_name, ReduceOp::Mean), "ReduceSum" => eval_reduce(node, inputs, &out_name, ReduceOp::Sum), "ReduceMax" => eval_reduce(node, inputs, &out_name, 
ReduceOp::Max),
        "ReduceMin" => eval_reduce(node, inputs, &out_name, ReduceOp::Min),
        "Resize" => eval_resize(node, inputs, &out_name),
        "Expand" if inputs.len() == 2 => eval_expand(inputs, &out_name),
        "Tile" if inputs.len() == 2 => eval_tile(inputs, &out_name),
        "ScatterND" if inputs.len() == 3 => eval_scatter_nd(inputs, &out_name),
        "Split" => eval_split(node, inputs, &node.output),
        _ => None,
    }
}

/// Constant-fold ONNX `Expand`: broadcast `inputs[0]` to the shape stored in
/// `inputs[1]`.
///
/// INT64 data is copied through the integer path; every other dtype goes
/// through `tensor_to_f32` and is re-emitted with the input's dtype. Returns
/// `None` (no fold) when the shape is empty/unreadable, the shapes do not
/// broadcast, the data cannot be decoded, or the output would exceed
/// `MAX_BROADCAST_ELEMENTS`.
fn eval_expand(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {
    let data = inputs[0];
    let shape = tensor_to_i64(inputs[1]);
    if shape.is_empty() {
        return None;
    }
    let out_dims = broadcast_shape(&data.dims, &shape)?;
    let total = broadcast_total(&out_dims)?;
    if data.data_type == TensorProto::INT64 {
        let v = tensor_to_i64(data);
        if v.is_empty() {
            return None;
        }
        let mut result = Vec::with_capacity(total);
        for i in 0..total {
            let di = broadcast_index(i, &out_dims, &data.dims);
            result.push(v[di]);
        }
        let t = TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims: out_dims,
            int64_data: result,
            ..Default::default()
        };
        return Some(vec![(out_name.to_string(), t)]);
    }
    let v = tensor_to_f32(data);
    if v.is_empty() {
        return None;
    }
    let mut result = Vec::with_capacity(total);
    for i in 0..total {
        let di = broadcast_index(i, &out_dims, &data.dims);
        result.push(v[di]);
    }
    let t = make_f32_tensor(out_name, &out_dims, &result, data.data_type);
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `Tile`: repeat `inputs[0]` `repeats[i]` times along
/// each axis `i`.
///
/// `repeats` (from `inputs[1]`) must have one entry per data axis and every
/// entry must be non-negative. Output dims are computed with checked
/// multiplication so a hostile repeat count cannot wrap around to a small
/// (or zero) size and produce a silently wrong fold. Returns `None` when
/// any of these checks fail, the data cannot be decoded, or the output is
/// larger than `MAX_BROADCAST_ELEMENTS`.
fn eval_tile(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {
    let data = inputs[0];
    let repeats = tensor_to_i64(inputs[1]);
    if repeats.is_empty() || repeats.len() != data.dims.len() {
        return None;
    }
    if repeats.iter().any(|&r| r < 0) {
        return None;
    }
    let rank = data.dims.len();
    // checked_mul: `d * r` on raw i64 panics in debug and wraps in release.
    let out_dims: Vec<i64> = data
        .dims
        .iter()
        .zip(&repeats)
        .map(|(&d, &r)| d.checked_mul(r))
        .collect::<Option<Vec<i64>>>()?;
    let total = broadcast_total(&out_dims)?;
    let in_strides: Vec<usize> = {
        let mut s = vec![1usize; rank];
        for i in (0..rank.saturating_sub(1)).rev() {
            s[i] = s[i + 1] * data.dims[i + 1] as usize;
        }
        s
    };
    let out_strides: Vec<usize> = {
        let mut s = vec![1usize; rank];
        for i in (0..rank.saturating_sub(1)).rev() {
            s[i] = s[i + 1] * out_dims[i + 1] as usize;
        }
        s
    };
    if data.data_type == TensorProto::INT64 {
        let v = tensor_to_i64(data);
        if v.is_empty() {
            return None;
        }
        let mut result = vec![0i64; total];
        for (o, out_slot) in result.iter_mut().enumerate().take(total) {
            // Decompose the output index per axis; `coord % dim` wraps each
            // coordinate back into the source tensor.
            let mut src = 0usize;
            let mut rem = o;
            for i in 0..rank {
                let coord = rem / out_strides[i];
                rem %= out_strides[i];
                src += (coord % data.dims[i] as usize) * in_strides[i];
            }
            *out_slot = v[src];
        }
        let t = TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims: out_dims,
            int64_data: result,
            ..Default::default()
        };
        return Some(vec![(out_name.to_string(), t)]);
    }
    let v = tensor_to_f32(data);
    if v.is_empty() {
        return None;
    }
    let mut result = vec![0f32; total];
    for (o, out_slot) in result.iter_mut().enumerate().take(total) {
        let mut src = 0usize;
        let mut rem = o;
        for i in 0..rank {
            let coord = rem / out_strides[i];
            rem %= out_strides[i];
            src += (coord % data.dims[i] as usize) * in_strides[i];
        }
        *out_slot = v[src];
    }
    let t = make_f32_tensor(out_name, &out_dims, &result, data.data_type);
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `ConstantOfShape`: build a tensor whose dims are the
/// values of `shape_t`, filled with the first element of the `value`
/// attribute tensor.
///
/// Without a `value` attribute (or with an attribute lacking a tensor) the
/// fill is FLOAT 0. INT64, INT32 and BOOL fills go to the integer fields
/// (BOOL as 0/1 in `int32_data`); any other dtype goes through
/// `make_f32_tensor`. Returns `None` for an empty/unreadable shape or an
/// output larger than `MAX_BROADCAST_ELEMENTS`.
fn eval_constant_of_shape(
    node: &NodeProto,
    shape_t: &TensorProto,
    out_name: &str,
) -> Option<Vec<(String, TensorProto)>> {
    let dims = tensor_to_i64(shape_t);
    if dims.is_empty() {
        return None;
    }
    let total = broadcast_total(&dims)?;
    let (dtype, f_val, i_val) = match node.attribute.iter().find(|a| a.name == "value") {
        Some(a) => match a.t.as_ref() {
            Some(t) => {
                let fv = tensor_to_f32(t).first().copied().unwrap_or(0.0);
                // Prefer the exact integer decode; fall back to the float value.
                let iv = tensor_to_i64(t).first().copied().unwrap_or(fv as i64);
                (t.data_type, fv, iv)
            }
            None => (TensorProto::FLOAT, 0.0, 0),
        },
        None => (TensorProto::FLOAT, 0.0, 0),
    };
    let t = match dtype {
        TensorProto::INT64 => TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims: dims.clone(),
            int64_data: vec![i_val; total],
            ..Default::default()
        },
        TensorProto::INT32 => TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT32,
            dims: dims.clone(),
            int32_data: vec![i_val as i32; total],
            ..Default::default()
        },
        TensorProto::BOOL => TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::BOOL,
            dims: dims.clone(),
            int32_data: vec![(i_val != 0) as i32; total],
            ..Default::default()
        },
        _ => make_f32_tensor(out_name, &dims, &vec![f_val; total], dtype),
    };
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `Where`: pick `inputs[1]` where the condition
/// (`inputs[0]`, decoded with `tensor_to_i64`, non-zero = true) holds and
/// `inputs[2]` elsewhere, with three-way broadcasting.
///
/// The integer path is used only when both branches are INT64; otherwise
/// values are computed as f32 and emitted with `inputs[1]`'s dtype.
fn eval_where(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {
    let cond = tensor_to_i64(inputs[0]);
    let data_type = if inputs[1].data_type == TensorProto::INT64
        && inputs[2].data_type == TensorProto::INT64
    {
        TensorProto::INT64
    } else {
        TensorProto::FLOAT
    };
    if data_type == TensorProto::INT64 {
        let x = tensor_to_i64(inputs[1]);
        let y = tensor_to_i64(inputs[2]);
        if x.is_empty() || y.is_empty() || cond.is_empty() {
            return None;
        }
        let xy_dims = broadcast_shape(&inputs[1].dims, &inputs[2].dims)?;
        let out_dims = broadcast_shape(&xy_dims, &inputs[0].dims)?;
        let total = broadcast_total(&out_dims)?;
        let mut result = Vec::with_capacity(total);
        for i in 0..total {
            let ci = broadcast_index(i, &out_dims, &inputs[0].dims);
            let xi = broadcast_index(i, &out_dims, &inputs[1].dims);
            let yi = broadcast_index(i, &out_dims, &inputs[2].dims);
            result.push(if cond[ci] != 0 { x[xi] } else { y[yi] });
        }
        let t = TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims: out_dims,
            int64_data: result,
            ..Default::default()
        };
        return Some(vec![(out_name.to_string(), t)]);
    }
    let x = tensor_to_f32(inputs[1]);
    let y = tensor_to_f32(inputs[2]);
    if x.is_empty() || y.is_empty() || cond.is_empty() {
        return None;
    }
    let xy_dims = broadcast_shape(&inputs[1].dims, &inputs[2].dims)?;
    let out_dims = broadcast_shape(&xy_dims, &inputs[0].dims)?;
    let total = broadcast_total(&out_dims)?;
    let mut result = Vec::with_capacity(total);
    for i in 0..total {
        let ci = broadcast_index(i, &out_dims, &inputs[0].dims);
        let xi = broadcast_index(i, &out_dims, &inputs[1].dims);
        let yi = broadcast_index(i,
&out_dims, &inputs[2].dims);
        result.push(if cond[ci] != 0 { x[xi] } else { y[yi] });
    }
    let t = make_f32_tensor(out_name, &out_dims, &result, inputs[1].data_type);
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `Range(start, limit, delta)`.
///
/// Follows the operator definition directly:
/// `n = max(ceil((limit - start) / delta), 0)` and
/// `output[i] = start + i * delta` for `i in 0..n`.
///
/// - All three inputs INT64: the arithmetic is done in i128, so
///   `limit - start` cannot overflow, and every emitted value lies between
///   `start` and `limit` (so it fits back into i64).
/// - Otherwise: the arithmetic is done in f32 and the result uses
///   `inputs[0]`'s dtype. Elements are computed by multiplication, not by
///   accumulating `delta`, so there is no drift and the length is exactly `n`.
///
/// Returns `None` when `delta` is zero, a float input is not finite, or `n`
/// exceeds `MAX_BROADCAST_ELEMENTS`.
fn eval_range(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {
    let is_int = inputs[0].data_type == TensorProto::INT64
        && inputs[1].data_type == TensorProto::INT64
        && inputs[2].data_type == TensorProto::INT64;
    if is_int {
        let start = tensor_to_i64(inputs[0]).first().copied()?;
        let limit = tensor_to_i64(inputs[1]).first().copied()?;
        let delta = tensor_to_i64(inputs[2]).first().copied()?;
        if delta == 0 {
            return None;
        }
        // Widen before subtracting: `limit - start` can overflow i64.
        let span = i128::from(limit) - i128::from(start);
        let d = i128::from(delta);
        let count = if (d > 0 && span > 0) || (d < 0 && span < 0) {
            // ceil(span / d) for same-signed operands.
            usize::try_from((span + d - d.signum()) / d).ok()?
        } else {
            0
        };
        if count > MAX_BROADCAST_ELEMENTS {
            return None;
        }
        let out = (0..count)
            .map(|i| i64::try_from(i128::from(start) + i as i128 * d).ok())
            .collect::<Option<Vec<i64>>>()?;
        let t = TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims: vec![out.len() as i64],
            int64_data: out,
            ..Default::default()
        };
        return Some(vec![(out_name.to_string(), t)]);
    }
    let start = tensor_to_f32(inputs[0]).first().copied()?;
    let limit = tensor_to_f32(inputs[1]).first().copied()?;
    let delta = tensor_to_f32(inputs[2]).first().copied()?;
    if delta == 0.0 || !delta.is_finite() || !start.is_finite() || !limit.is_finite() {
        return None;
    }
    let count = ((limit - start) / delta).ceil();
    if count <= 0.0 {
        let dims = vec![0i64];
        let t = make_f32_tensor(out_name, &dims, &[], inputs[0].data_type);
        return Some(vec![(out_name.to_string(), t)]);
    }
    // `as usize` saturates (e.g. for +inf), which the bound check rejects.
    if count as usize > MAX_BROADCAST_ELEMENTS {
        return None;
    }
    let count = count as usize;
    let out: Vec<f32> = (0..count).map(|i| start + i as f32 * delta).collect();
    let dims = vec![out.len() as i64];
    let t = make_f32_tensor(out_name, &dims, &out, inputs[0].data_type);
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold a broadcasting comparison (`Equal`, `Less`, ...).
///
/// Uses `f_i64` when both operands are INT64 and `f_f32` otherwise. The
/// result is a BOOL tensor with 0/1 values stored in `int32_data`.
fn eval_cmp(
    inputs: &[&TensorProto],
    out_name: &str,
    f_f32: fn(f32, f32) -> bool,
    f_i64: fn(i64, i64) -> bool,
) -> Option<Vec<(String, TensorProto)>> {
    if inputs.len() < 2 {
        return None;
    }
    let out_dims = broadcast_shape(&inputs[0].dims, &inputs[1].dims)?;
    let total = broadcast_total(&out_dims)?;
    let both_int =
        inputs[0].data_type == TensorProto::INT64 && inputs[1].data_type == TensorProto::INT64;
    let mut result = Vec::with_capacity(total);
    if both_int {
        let a = tensor_to_i64(inputs[0]);
        let b = tensor_to_i64(inputs[1]);
        if a.is_empty() || b.is_empty() {
            return None;
        }
        for i in 0..total {
            let ai = broadcast_index(i, &out_dims, &inputs[0].dims);
            let bi = broadcast_index(i, &out_dims, &inputs[1].dims);
            result.push(f_i64(a[ai], b[bi]) as i32);
        }
    } else {
        let a = tensor_to_f32(inputs[0]);
        let b = tensor_to_f32(inputs[1]);
        if a.is_empty() || b.is_empty() {
            return None;
        }
        for i in 0..total {
            let ai = broadcast_index(i, &out_dims, &inputs[0].dims);
            let bi = broadcast_index(i, &out_dims, &inputs[1].dims);
            result.push(f_f32(a[ai], b[bi]) as i32);
        }
    }
    let t = TensorProto {
        name: out_name.to_string(),
        data_type: TensorProto::BOOL,
        dims: out_dims,
        int32_data: result,
        ..Default::default()
    };
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `Not`: each input value (decoded with
/// `tensor_to_i64`) becomes 1 if it is zero and 0 otherwise, stored as BOOL.
fn eval_not(input: &TensorProto, out_name: &str) -> Option<Vec<(String, TensorProto)>> {
    let vals = tensor_to_i64(input);
    if vals.is_empty() {
        return None;
    }
    let t = TensorProto {
        name: out_name.to_string(),
        data_type: TensorProto::BOOL,
        dims: input.dims.clone(),
        int32_data: vals.iter().map(|&v| (v == 0) as i32).collect(),
        ..Default::default()
    };
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold a broadcasting logical op (`And`, `Or`, `Xor`).
///
/// Operands are decoded with `tensor_to_i64` and normalised to 0/1 before
/// `f` is applied. The result is BOOL in `int32_data`.
fn eval_logical(
    inputs: &[&TensorProto],
    out_name: &str,
    f: fn(i32, i32) -> i32,
) -> Option<Vec<(String, TensorProto)>> {
    if inputs.len() < 2 {
        return None;
    }
    let a = tensor_to_i64(inputs[0]);
    let b = tensor_to_i64(inputs[1]);
    if a.is_empty() || b.is_empty() {
        return None;
    }
    let out_dims = broadcast_shape(&inputs[0].dims, &inputs[1].dims)?;
    let total = broadcast_total(&out_dims)?;
    let mut result = Vec::with_capacity(total);
    for i in 0..total {
        let ai = broadcast_index(i, &out_dims, &inputs[0].dims);
        let bi = broadcast_index(i, &out_dims, &inputs[1].dims);
        result.push(f((a[ai] != 0) as i32, (b[bi] != 0) as i32));
    }
    let t = TensorProto {
        name: out_name.to_string(),
        data_type: TensorProto::BOOL,
        dims: out_dims,
        int32_data: result,
        ..Default::default()
    };
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `Transpose`.
///
/// `perm` must list every axis exactly once; when the attribute is absent,
/// the axes are reversed. Rank-0 inputs are not folded.
fn eval_transpose(
    node: &NodeProto,
    input: &TensorProto,
    out_name: &str,
) -> Option<Vec<(String, TensorProto)>> {
    let rank = input.dims.len();
    if rank == 0 {
        return None;
    }
    let perm: Vec<usize> = match node.attribute.iter().find(|a| a.name == "perm") {
        Some(attr) => {
            if attr.ints.len() != rank {
                return None;
            }
            let mut out = Vec::with_capacity(rank);
            let mut seen = vec![false; rank];
            for &raw in &attr.ints {
                if raw < 0 || (raw as usize) >= rank {
                    return None;
                }
                let p = raw as usize;
                if seen[p] {
                    return None;
                }
                seen[p] = true;
                out.push(p);
            }
            out
        }
        None => (0..rank).rev().collect(),
    };
    let out_dims: Vec<i64> = perm.iter().map(|&p| input.dims[p]).collect();
    let total = broadcast_total(&out_dims)?;
    let src_strides = {
        let mut s = vec![1i64; rank];
        for i in (0..rank.saturating_sub(1)).rev() {
            s[i] = s[i + 1] * input.dims[i + 1];
        }
        s
    };
    let out_strides = {
        let mut s = vec![1i64; rank];
        for i in (0..rank.saturating_sub(1)).rev() {
            s[i] = s[i + 1] * out_dims[i + 1];
        }
        s
    };
    // Output axis `i` reads source axis `perm[i]`.
    let permute_index = |out_linear: usize| -> usize {
        let mut src = 0i64;
        let mut rem = out_linear as i64;
        for i in 0..rank {
            let coord = rem / out_strides[i];
            rem %= out_strides[i];
            src += coord * src_strides[perm[i]];
        }
        src as usize
    };
    if input.data_type == TensorProto::INT64 {
        let vals = tensor_to_i64(input);
        if vals.is_empty() {
            return None;
        }
        let mut result = Vec::with_capacity(total);
        for i in 0..total {
            result.push(vals[permute_index(i)]);
        }
        let t = TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims:
out_dims,
            int64_data: result,
            ..Default::default()
        };
        return Some(vec![(out_name.to_string(), t)]);
    }
    let vals = tensor_to_f32(input);
    if vals.is_empty() {
        return None;
    }
    let mut result = Vec::with_capacity(total);
    for i in 0..total {
        result.push(vals[permute_index(i)]);
    }
    let t = make_f32_tensor(out_name, &out_dims, &result, input.data_type);
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `Resize` when exactly the two trailing axes change size.
///
/// Input matching:
/// - Inputs are matched positionally against `node.input`.
/// - An empty name marks an omitted optional input and does not consume an
///   entry of `inputs`.
///
/// Output shape:
/// - Taken from `sizes` (input 3) when present.
/// - Otherwise `trunc(dim * scale)` from `scales` (input 2).
/// - Sampling uses the effective per-axis scale `out_dim / in_dim`.
///
/// Supported attribute values:
/// - `mode`: nearest, linear, cubic.
/// - `coordinate_transformation_mode`: half_pixel, pytorch_half_pixel,
///   align_corners, asymmetric, tf_half_pixel_for_nn.
/// - `nearest_mode`: round_prefer_floor (the ONNX default),
///   round_prefer_ceil, floor, ceil.
///
/// Any other value is refused (`None`) instead of being approximated. This
/// includes tf_crop_and_resize, which needs `roi`.
#[allow(clippy::too_many_lines)]
fn eval_resize(
    node: &NodeProto,
    inputs: &[&TensorProto],
    out_name: &str,
) -> Option<Vec<(String, TensorProto)>> {
    let named: Vec<(&str, Option<&TensorProto>)> = {
        let mut it = inputs.iter().copied();
        node.input
            .iter()
            .map(|name| {
                let entry = if name.is_empty() { None } else { it.next() };
                (name.as_str(), entry)
            })
            .collect()
    };
    let x = named.first().and_then(|(_, t)| *t)?;
    if x.dims.len() < 2 {
        return None;
    }
    let rank = x.dims.len();
    let vals = tensor_to_f32(x);
    if vals.is_empty() {
        return None;
    }
    let str_attr = |key: &str, default: &str| -> String {
        node.attribute
            .iter()
            .find(|a| a.name == key)
            .map(|a| std::str::from_utf8(&a.s).unwrap_or("").to_string())
            .unwrap_or_else(|| default.to_string())
    };
    let mode = str_attr("mode", "nearest");
    let ctm = str_attr("coordinate_transformation_mode", "half_pixel");
    // Previously unknown transforms silently fell back to half_pixel,
    // producing a wrong fold. Refuse them instead.
    if !matches!(
        ctm.as_str(),
        "half_pixel" | "pytorch_half_pixel" | "align_corners" | "asymmetric" | "tf_half_pixel_for_nn"
    ) {
        return None;
    }
    let nearest_mode = match str_attr("nearest_mode", "round_prefer_floor").as_str() {
        "round_prefer_floor" => Some(NearestMode::RoundPreferFloor),
        "round_prefer_ceil" => Some(NearestMode::RoundPreferCeil),
        "floor" => Some(NearestMode::Floor),
        "ceil" => Some(NearestMode::Ceil),
        _ => None,
    };
    let cubic_a = node
        .attribute
        .iter()
        .find(|a| a.name == "cubic_coeff_a")
        .map(|a| a.f)
        .unwrap_or(-0.75);
    let exclude_outside = node
        .attribute
        .iter()
        .find(|a| a.name == "exclude_outside")
        .map(|a| a.i != 0)
        .unwrap_or(false);
    let extrapolation = node
        .attribute
        .iter()
        .find(|a| a.name == "extrapolation_value")
        .map(|a| a.f)
        .unwrap_or(0.0);
    // An empty or all-zero-dim tensor counts as "not provided".
    let sizes_opt = named.get(3).and_then(|(_, t)| *t).and_then(|t| {
        if t.dims.is_empty() || t.dims.iter().all(|&d| d == 0) {
            None
        } else {
            let v = tensor_to_i64(t);
            if v.len() == rank {
                Some(v)
            } else {
                None
            }
        }
    });
    let scales_opt = named.get(2).and_then(|(_, t)| *t).and_then(|t| {
        if t.dims.is_empty() || t.dims.iter().all(|&d| d == 0) {
            None
        } else {
            let v = tensor_to_f32(t);
            if v.len() == rank {
                Some(v)
            } else {
                None
            }
        }
    });
    let out_dims: Vec<i64> = if let Some(sizes) = sizes_opt {
        sizes
    } else if let Some(scales) = scales_opt {
        x.dims
            .iter()
            .zip(&scales)
            .map(|(&d, &s)| (d as f32 * s) as i64)
            .collect()
    } else {
        return None;
    };
    let total_out = broadcast_total(&out_dims)?;
    let scales_eff: Vec<f32> = x
        .dims
        .iter()
        .zip(&out_dims)
        .map(|(&s, &o)| o as f32 / s as f32)
        .collect();
    // Map an output coordinate on axis `d` back to a (fractional) input one.
    let coord = |out_i: i64, d: usize| -> f32 {
        let out_d = out_dims[d] as f32;
        let in_d = x.dims[d] as f32;
        let s = scales_eff[d];
        let o = out_i as f32;
        match ctm.as_str() {
            "pytorch_half_pixel" => {
                if out_d > 1.0 {
                    (o + 0.5) / s - 0.5
                } else {
                    0.0
                }
            }
            "align_corners" => {
                if out_d > 1.0 {
                    o * (in_d - 1.0) / (out_d - 1.0)
                } else {
                    0.0
                }
            }
            "asymmetric" => o / s,
            "tf_half_pixel_for_nn" => (o + 0.5) / s,
            // "half_pixel"; other values were rejected above.
            _ => (o + 0.5) / s - 0.5,
        }
    };
    let mode_kind = match mode.as_str() {
        "cubic" => ResizeMode::Cubic,
        "linear" => ResizeMode::Linear,
        "nearest" => ResizeMode::Nearest,
        _ => return None,
    };
    // `nearest_mode` only matters for nearest sampling, so an unknown value
    // only blocks the fold in that mode.
    let nearest_mode = match (mode_kind, nearest_mode) {
        (_, Some(m)) => m,
        (ResizeMode::Nearest, None) => return None,
        (_, None) => NearestMode::RoundPreferFloor,
    };
    let mut result = vec![0f32; total_out];
    let resize_axes: Vec<usize> = (0..rank).filter(|&d| x.dims[d] != out_dims[d]).collect();
    if resize_axes.is_empty() {
        let t = make_f32_tensor(out_name, &out_dims, &vals, x.data_type);
        return Some(vec![(out_name.to_string(), t)]);
    }
    // Only a resize of the last two (adjacent) axes is folded.
    if !(resize_axes.len() == 2
        && resize_axes[0] + 1 == resize_axes[1]
        && resize_axes[1] + 1 == rank)
    {
        return None;
    }
    let h_axis = resize_axes[0];
    let w_axis = resize_axes[1];
    let outer_total: usize = x.dims[..h_axis].iter().map(|&d| d as usize).product();
    let in_h = x.dims[h_axis] as usize;
    let in_w = x.dims[w_axis] as usize;
    let out_h = out_dims[h_axis] as usize;
    let out_w = out_dims[w_axis] as usize;
    for outer in 0..outer_total {
        let in_plane = outer * in_h * in_w;
        let out_plane = outer * out_h * out_w;
        for oy in 0..out_h {
            let sy = coord(oy as i64, h_axis);
            for ox in 0..out_w {
                let sx = coord(ox as i64, w_axis);
                let v = match mode_kind {
                    ResizeMode::Nearest => {
                        let yi = nearest_idx_with_mode(sy, in_h, nearest_mode);
                        let xi = nearest_idx_with_mode(sx, in_w, nearest_mode);
                        vals[in_plane + yi * in_w + xi]
                    }
                    ResizeMode::Linear => sample_linear_2d(
                        &vals[in_plane..in_plane + in_h * in_w],
                        in_h,
                        in_w,
                        sy,
                        sx,
                        exclude_outside,
                        extrapolation,
                    ),
                    ResizeMode::Cubic => sample_cubic_2d(
                        &vals[in_plane..in_plane + in_h * in_w],
                        in_h,
                        in_w,
                        sy,
                        sx,
                        cubic_a,
                        exclude_outside,
                        extrapolation,
                    ),
                };
                result[out_plane + oy * out_w + ox] = v;
            }
        }
    }
    let t = make_f32_tensor(out_name, &out_dims, &result, x.data_type);
    Some(vec![(out_name.to_string(), t)])
}

/// Interpolation kernel selected by the Resize `mode` attribute.
#[derive(Clone, Copy)]
enum ResizeMode {
    Nearest,
    Linear,
    Cubic,
}

/// Rounding rule selected by the Resize `nearest_mode` attribute.
#[derive(Clone, Copy)]
enum NearestMode {
    RoundPreferFloor,
    RoundPreferCeil,
    Floor,
    Ceil,
}

/// Round `s` to the nearest index, with ties going away from zero, and clamp
/// to `[0, dim - 1]`.
///
/// Negative coordinates map to 0. For the non-negative coordinates that can
/// reach the rounding step, this is the `round_prefer_ceil` rule.
fn nearest_idx(s: f32, dim: usize) -> usize {
    if s < 0.0 {
        0
    } else {
        let i = s.round() as isize;
        if i >= dim as isize {
            dim - 1
        } else {
            i as usize
        }
    }
}

/// Snap the fractional source coordinate `s` to an index in
/// `[0, dim - 1]` using the given ONNX `nearest_mode` rounding rule.
fn nearest_idx_with_mode(s: f32, dim: usize, mode: NearestMode) -> usize {
    let snapped = match mode {
        NearestMode::RoundPreferCeil => return nearest_idx(s, dim),
        // ceil(s - 0.5) rounds .5 ties down and everything else to nearest.
        NearestMode::RoundPreferFloor => (s - 0.5).ceil(),
        NearestMode::Floor => s.floor(),
        NearestMode::Ceil => s.ceil(),
    };
    if snapped < 0.0 {
        0
    } else {
        (snapped as usize).min(dim.saturating_sub(1))
    }
}

/// Bilinear sample of a row-major `h x w` plane at fractional `(sy, sx)`.
///
/// Neighbour indices are clamped to the plane. With `exclude_outside`, the
/// function returns `extrap` when both neighbours on either axis fall
/// outside the plane.
fn sample_linear_2d(
    plane: &[f32],
    h: usize,
    w: usize,
    sy: f32,
    sx: f32,
    exclude_outside: bool,
    extrap: f32,
) -> f32 {
    let (y0_in, y0) = clamp_axis(sy.floor() as isize, h);
    let (y1_in, y1) = clamp_axis(sy.floor() as isize + 1, h);
    let (x0_in, x0) = clamp_axis(sx.floor() as isize, w);
    let (x1_in, x1) = clamp_axis(sx.floor() as isize + 1, w);
    if exclude_outside && (!y0_in && !y1_in || !x0_in && !x1_in) {
        return extrap;
    }
    let dy = sy - sy.floor();
    let dx = sx - sx.floor();
    let v00 = plane[y0 * w + x0];
    let v01 = plane[y0 * w + x1];
    let v10 = plane[y1 * w + x0];
    let v11 = plane[y1 * w + x1];
    let a = v00 * (1.0 - dx) + v01 * dx;
    let b = v10 * (1.0 - dx) + v11 * dx;
    a * (1.0 - dy) + b * dy
}

/// Bicubic sample of a row-major `h x w` plane at fractional `(sy, sx)`,
/// using a 4x4 neighbourhood and cubic coefficient `a_coef`.
#[allow(clippy::too_many_arguments)]
fn sample_cubic_2d(
    plane: &[f32],
    h: usize,
    w: usize,
    sy: f32,
    sx: f32,
    a_coef: f32,
    exclude_outside: bool,
    extrap: f32,
) -> f32 {
    let fx = sx.floor();
    let fy = sy.floor();
    let dx = sx - fx;
    let dy = sy - fy;
    let
wx = cubic_weights(dx, a_coef);
    let wy = cubic_weights(dy, a_coef);
    let mut wx_eff = wx;
    let mut wy_eff = wy;
    if exclude_outside {
        // Zero the taps that fall outside the plane, then renormalise.
        for (i, w_ref) in wx_eff.iter_mut().enumerate() {
            let xi = fx as isize - 1 + i as isize;
            if xi < 0 || xi >= w as isize {
                *w_ref = 0.0;
            }
        }
        for (i, w_ref) in wy_eff.iter_mut().enumerate() {
            let yi = fy as isize - 1 + i as isize;
            if yi < 0 || yi >= h as isize {
                *w_ref = 0.0;
            }
        }
        let sx_sum: f32 = wx_eff.iter().sum();
        let sy_sum: f32 = wy_eff.iter().sum();
        if sx_sum == 0.0 || sy_sum == 0.0 {
            return extrap;
        }
        for w_ref in &mut wx_eff {
            *w_ref /= sx_sum;
        }
        for w_ref in &mut wy_eff {
            *w_ref /= sy_sum;
        }
    }
    let mut out = 0f32;
    for (iy, &wyv) in wy_eff.iter().enumerate() {
        if wyv == 0.0 {
            continue;
        }
        let yi = (fy as isize - 1 + iy as isize).clamp(0, h as isize - 1) as usize;
        let mut row_sum = 0f32;
        for (ix, &wxv) in wx_eff.iter().enumerate() {
            if wxv == 0.0 {
                continue;
            }
            let xi = (fx as isize - 1 + ix as isize).clamp(0, w as isize - 1) as usize;
            row_sum += plane[yi * w + xi] * wxv;
        }
        out += row_sum * wyv;
    }
    out
}

/// Cubic weights for the four taps at offsets -1, 0, +1, +2 from
/// `floor(coord)`, where `t` is the fractional part of the coordinate.
fn cubic_weights(t: f32, a: f32) -> [f32; 4] {
    let t1 = 1.0 + t;
    let t2 = t;
    let t3 = 1.0 - t;
    let t4 = 2.0 - t;
    [
        cubic_kernel(t1, a),
        cubic_kernel(t2, a),
        cubic_kernel(t3, a),
        cubic_kernel(t4, a),
    ]
}

/// Piecewise cubic convolution kernel (Keys form) with coefficient `a`.
/// It is zero for |x| >= 2.
fn cubic_kernel(x: f32, a: f32) -> f32 {
    let ax = x.abs();
    if ax <= 1.0 {
        (a + 2.0) * ax.powi(3) - (a + 3.0) * ax.powi(2) + 1.0
    } else if ax < 2.0 {
        a * ax.powi(3) - 5.0 * a * ax.powi(2) + 8.0 * a * ax - 4.0 * a
    } else {
        0.0
    }
}

/// Clamp index `i` into `[0, dim - 1]`. Returns whether it was already in
/// bounds, together with the clamped index.
fn clamp_axis(i: isize, dim: usize) -> (bool, usize) {
    if i < 0 {
        (false, 0)
    } else if i >= dim as isize {
        (false, dim - 1)
    } else {
        (true, i as usize)
    }
}

/// Which ONNX `Reduce*` operator `eval_reduce` folds.
#[derive(Clone, Copy)]
enum ReduceOp {
    Sum,
    Mean,
    Max,
    Min,
}

/// Constant-fold ONNX `ReduceSum` / `ReduceMean` / `ReduceMax` / `ReduceMin`
/// over a FLOAT, DOUBLE or FLOAT16 constant.
///
/// Axis selection:
/// - Axes come from the optional second input, or else from the legacy
///   `axes` attribute.
/// - An empty axis list reduces every axis, unless `noop_with_empty_axes`
///   is non-zero, in which case the input passes through unchanged. The
///   previous code always treated an empty list as an identity.
/// - Negative axes count from the end.
///
/// `keepdims` defaults to true.
///
/// Returns `None` when:
/// - the input is non-float or rank-0;
/// - an axis is out of range or repeated (a repeat would double-count the
///   `Mean` divisor);
/// - a non-empty axes tensor cannot be decoded;
/// - the tensor is too large.
fn eval_reduce(
    node: &NodeProto,
    inputs: &[&TensorProto],
    out_name: &str,
    op: ReduceOp,
) -> Option<Vec<(String, TensorProto)>> {
    let input = inputs[0];
    let rank = input.dims.len();
    if rank == 0 {
        return None;
    }
    // Reduce* for non-floating-point tensors would lose precision
    // through the tensor_to_f32 path below; refuse to fold them so
    // the compiler can emit a proper integer reduction.
    if !matches!(
        input.data_type,
        TensorProto::FLOAT | TensorProto::DOUBLE | TensorProto::FLOAT16
    ) {
        return None;
    }
    let keepdims = node
        .attribute
        .iter()
        .find(|a| a.name == "keepdims")
        .map(|a| a.i != 0)
        .unwrap_or(true);
    let noop_with_empty_axes = node
        .attribute
        .iter()
        .find(|a| a.name == "noop_with_empty_axes")
        .map(|a| a.i != 0)
        .unwrap_or(false);
    let mut axes: Vec<i64> = if inputs.len() >= 2 {
        let decoded = tensor_to_i64(inputs[1]);
        // Only a tensor with a zero-sized dim is genuinely empty. An empty
        // decode of anything else means "unreadable", so do not guess.
        if decoded.is_empty() && !inputs[1].dims.contains(&0) {
            return None;
        }
        decoded
    } else {
        node.attribute
            .iter()
            .find(|a| a.name == "axes")
            .map(|a| a.ints.clone())
            .unwrap_or_default()
    };
    if axes.is_empty() && !noop_with_empty_axes {
        axes = (0..rank as i64).collect();
    }
    let mut norm_axes: Vec<usize> = Vec::with_capacity(axes.len());
    for &a in &axes {
        let ax = if a < 0 { a + rank as i64 } else { a };
        if ax < 0 || ax >= rank as i64 {
            return None;
        }
        norm_axes.push(ax as usize);
    }
    norm_axes.sort_unstable();
    if norm_axes.windows(2).any(|w| w[0] == w[1]) {
        return None;
    }
    let mut out_dims_full = input.dims.clone();
    for &ax in &norm_axes {
        out_dims_full[ax] = 1;
    }
    let out_dims: Vec<i64> = if keepdims {
        out_dims_full.clone()
    } else {
        out_dims_full
            .iter()
            .enumerate()
            .filter(|(i, _)| !norm_axes.contains(i))
            .map(|(_, &d)| d)
            .collect()
    };
    let total_out = broadcast_total(&out_dims_full)?;
    let total_in = broadcast_total(&input.dims)?;
    let vals = tensor_to_f32(input);
    if vals.is_empty() {
        return None;
    }
    let reduced_count: i64 = norm_axes.iter().map(|&a| input.dims[a]).product();
    let mut accum = vec![
        match op {
            ReduceOp::Sum | ReduceOp::Mean => 0.0f32,
            ReduceOp::Max => f32::NEG_INFINITY,
            ReduceOp::Min => f32::INFINITY,
        };
        total_out
    ];
    for (in_idx, &v) in vals.iter().enumerate().take(total_in) {
        // Walk axes innermost-first. Reduced axes collapse to coordinate 0 in
        // the keepdims-shaped accumulator.
        let mut rem = in_idx as i64;
        let mut out_idx = 0i64;
        let mut out_stride = 1i64;
        for i in (0..rank).rev() {
            let dim_i = input.dims[i];
            let coord = rem % dim_i;
            rem /= dim_i;
            let coord_out = if norm_axes.contains(&i) { 0 } else { coord };
            out_idx += coord_out * out_stride;
            out_stride *= out_dims_full[i];
        }
        let o = out_idx as usize;
        accum[o] = match op {
            ReduceOp::Sum | ReduceOp::Mean => accum[o] + v,
            ReduceOp::Max => accum[o].max(v),
            ReduceOp::Min => accum[o].min(v),
        };
    }
    if matches!(op, ReduceOp::Mean) && reduced_count > 0 {
        for a in &mut accum {
            *a /= reduced_count as f32;
        }
    }
    let t = make_f32_tensor(out_name, &out_dims, &accum, input.data_type);
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `Cast` to INT64, INT32, FLOAT, DOUBLE or BOOL.
/// Any other target type is not folded.
///
/// Reading the source:
/// - Integer-typed sources (INT64 / INT32 / BOOL) are decoded with
///   `tensor_to_i64`. Integer→integer and integer→double casts therefore
///   keep values that `f32`'s 24-bit mantissa would round.
/// - If that decode yields nothing, or the source is floating-point, the
///   values go through `tensor_to_f32`, as before.
///
/// Conversions use Rust `as` semantics:
/// - float→integer truncates toward zero and saturates (NaN → 0);
/// - i64→i32 wraps.
fn eval_cast(
    node: &NodeProto,
    input: &TensorProto,
    out_name: &str,
) -> Option<Vec<(String, TensorProto)>> {
    let target_type = node
        .attribute
        .iter()
        .find(|a| a.name == "to")
        .map(|a| a.i as i32)?;
    let exact: Vec<i64> = if matches!(
        input.data_type,
        TensorProto::INT64 | TensorProto::INT32 | TensorProto::BOOL
    ) {
        tensor_to_i64(input)
    } else {
        Vec::new()
    };
    let floats = || -> Option<Vec<f32>> {
        let vals = tensor_to_f32(input);
        if vals.is_empty() {
            None
        } else {
            Some(vals)
        }
    };
    let base = TensorProto {
        name: out_name.to_string(),
        data_type: target_type,
        dims: input.dims.clone(),
        ..Default::default()
    };
    let t = match target_type {
        TensorProto::INT64 => {
            let data: Vec<i64> = if exact.is_empty() {
                floats()?.iter().map(|&v| v as i64).collect()
            } else {
                exact
            };
            TensorProto { int64_data: data, ..base }
        }
        TensorProto::INT32 => {
            let data: Vec<i32> = if exact.is_empty() {
                floats()?.iter().map(|&v| v as i32).collect()
            } else {
                exact.iter().map(|&v| v as i32).collect()
            };
            TensorProto { int32_data: data, ..base }
        }
        TensorProto::FLOAT => TensorProto {
            float_data: floats()?,
            ..base
        },
        TensorProto::DOUBLE => {
            let data: Vec<f64> = if exact.is_empty() {
                floats()?.iter().map(|&v| f64::from(v)).collect()
            } else {
                exact.iter().map(|&v| v as f64).collect()
            };
            TensorProto { double_data: data, ..base }
        }
        TensorProto::BOOL => {
            let data: Vec<i32> = if exact.is_empty() {
                floats()?.iter().map(|&v| i32::from(v != 0.0)).collect()
            } else {
                exact.iter().map(|&v| i32::from(v != 0)).collect()
            };
            TensorProto { int32_data: data, ..base }
        }
        _ => return None,
    };
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold an elementwise unary op `f`, evaluated in f32 and
/// re-emitted with the input's dtype.
fn eval_unary_f32(
    input: &TensorProto,
    out_name: &str,
    f: fn(f32) -> f32,
) -> Option<Vec<(String, TensorProto)>> {
    let vals: Vec<f32> =
tensor_to_f32(input).into_iter().map(f).collect();
    if vals.is_empty() {
        return None;
    }
    let out_type = input.data_type;
    let t = make_f32_tensor(out_name, &input.dims, &vals, out_type);
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold a broadcasting elementwise binary op `f`.
///
/// - When both operands are INT64, the op runs per element as
///   `f(a as f32, b as f32) as i64` and the result is INT64.
///   NOTE(review): that round-trip loses precision for |values| > 2^24;
///   an exact i64 variant would need a second function argument from the
///   dispatcher.
/// - Otherwise the op runs in f32 and the result keeps `inputs[0]`'s dtype,
///   matching `eval_unary_f32` and ONNX (e.g. `Pow`'s output type follows
///   its base). Previously the result was always tagged FLOAT, which
///   mislabeled DOUBLE / FLOAT16 / INT32 results.
fn eval_binary_f32(
    inputs: &[&TensorProto],
    out_name: &str,
    f: fn(f32, f32) -> f32,
) -> Option<Vec<(String, TensorProto)>> {
    if inputs.len() < 2 {
        return None;
    }
    let both_int64 =
        inputs[0].data_type == TensorProto::INT64 && inputs[1].data_type == TensorProto::INT64;
    if both_int64 {
        let a = tensor_to_i64(inputs[0]);
        let b = tensor_to_i64(inputs[1]);
        if a.is_empty() || b.is_empty() {
            return None;
        }
        let (result, dims) =
            broadcast_binary_i64(&a, &inputs[0].dims, &b, &inputs[1].dims, |x, y| {
                f(x as f32, y as f32) as i64
            })?;
        let t = TensorProto {
            name: out_name.to_string(),
            dims,
            data_type: TensorProto::INT64,
            int64_data: result,
            ..Default::default()
        };
        return Some(vec![(out_name.to_string(), t)]);
    }
    let a = tensor_to_f32(inputs[0]);
    let b = tensor_to_f32(inputs[1]);
    if a.is_empty() || b.is_empty() {
        return None;
    }
    let (result, dims) = broadcast_binary(&a, &inputs[0].dims, &b, &inputs[1].dims, f)?;
    let t = make_f32_tensor(out_name, &dims, &result, inputs[0].data_type);
    Some(vec![(out_name.to_string(), t)])
}

/// NumPy-style broadcast of two shapes: dims are aligned from the right, and
/// a size-1 dim stretches to match the other. Returns `None` if the shapes
/// are incompatible.
fn broadcast_shape(a_dims: &[i64], b_dims: &[i64]) -> Option<Vec<i64>> {
    let rank = a_dims.len().max(b_dims.len());
    let mut out = Vec::with_capacity(rank);
    for i in 0..rank {
        let da = if i < rank - a_dims.len() {
            1
        } else {
            a_dims[i - (rank - a_dims.len())]
        };
        let db = if i < rank - b_dims.len() {
            1
        } else {
            b_dims[i - (rank - b_dims.len())]
        };
        if da == db {
            out.push(da);
        } else if da == 1 {
            out.push(db);
        } else if db == 1 {
            out.push(da);
        } else {
            return None;
        }
    }
    Some(out)
}

/// Translate a flat index into the broadcast output (`out_dims`) into the
/// flat index of a source tensor with shape `src_dims`. Dims are right-aligned,
/// and size-1 source dims always read coordinate 0.
fn broadcast_index(out_idx: usize, out_dims: &[i64], src_dims: &[i64]) -> usize {
    let rank = out_dims.len();
    let src_rank = src_dims.len();
    let mut idx = 0;
    let mut stride = 1;
    for i in (0..src_rank).rev() {
        let out_i = rank - src_rank + i;
        let coord = (out_idx / out_dims[out_i +
1..].iter().product::<i64>().max(1) as usize)
            % out_dims[out_i] as usize;
        let src_coord = if src_dims[i] == 1 { 0 } else { coord };
        idx += src_coord * stride;
        stride *= src_dims[i] as usize;
    }
    idx
}

/// Upper bound on the number of elements any folded tensor may have.
const MAX_BROADCAST_ELEMENTS: usize = 100_000_000;

/// Element count of `out_dims`. Returns `None` on a negative dim, on `usize`
/// overflow, or when the count exceeds `MAX_BROADCAST_ELEMENTS`.
fn broadcast_total(out_dims: &[i64]) -> Option<usize> {
    let mut total: usize = 1;
    for &d in out_dims {
        let d = usize::try_from(d).ok()?;
        total = total.checked_mul(d)?;
        if total > MAX_BROADCAST_ELEMENTS {
            return None;
        }
    }
    Some(total)
}

/// Apply `f` element-wise to two broadcast f32 tensors. Returns the values
/// together with the output dims.
fn broadcast_binary(
    a: &[f32],
    a_dims: &[i64],
    b: &[f32],
    b_dims: &[i64],
    f: fn(f32, f32) -> f32,
) -> Option<(Vec<f32>, Vec<i64>)> {
    let out_dims = broadcast_shape(a_dims, b_dims)?;
    let total = broadcast_total(&out_dims)?;
    let mut result = Vec::with_capacity(total);
    for i in 0..total {
        let ai = broadcast_index(i, &out_dims, a_dims);
        let bi = broadcast_index(i, &out_dims, b_dims);
        result.push(f(a[ai], b[bi]));
    }
    Some((result, out_dims))
}

/// Apply `f` element-wise to two broadcast i64 tensors. Returns the values
/// together with the output dims.
fn broadcast_binary_i64(
    a: &[i64],
    a_dims: &[i64],
    b: &[i64],
    b_dims: &[i64],
    f: impl Fn(i64, i64) -> i64,
) -> Option<(Vec<i64>, Vec<i64>)> {
    let out_dims = broadcast_shape(a_dims, b_dims)?;
    let total = broadcast_total(&out_dims)?;
    let mut result = Vec::with_capacity(total);
    for i in 0..total {
        let ai = broadcast_index(i, &out_dims, a_dims);
        let bi = broadcast_index(i, &out_dims, b_dims);
        result.push(f(a[ai], b[bi]));
    }
    Some((result, out_dims))
}

/// Re-emit the elements of `input` unchanged under a new shape. Shared by
/// Reshape, Squeeze and Unsqueeze.
///
/// `dims_for` receives the decoded element count and returns the target dims,
/// or `None` to abort.
///
/// - INT64 tensors are copied through `tensor_to_i64`, so shape-tensor values
///   stay exact instead of round-tripping through f32.
/// - Other dtypes use `tensor_to_f32` / `make_f32_tensor`.
///
/// Refuses to fold when the data cannot be decoded or when the target dims
/// do not describe exactly that many elements.
fn emit_reshaped(
    input: &TensorProto,
    out_name: &str,
    dims_for: impl FnOnce(usize) -> Option<Vec<i64>>,
) -> Option<Vec<(String, TensorProto)>> {
    let t = if input.data_type == TensorProto::INT64 {
        let vals = tensor_to_i64(input);
        if vals.is_empty() {
            return None;
        }
        let dims = dims_for(vals.len()).filter(|d| dims_hold(d, vals.len()))?;
        TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims,
            int64_data: vals,
            ..Default::default()
        }
    } else {
        let vals = tensor_to_f32(input);
        if vals.is_empty() {
            return None;
        }
        let dims = dims_for(vals.len()).filter(|d| dims_hold(d, vals.len()))?;
        make_f32_tensor(out_name, &dims, &vals, input.data_type)
    };
    Some(vec![(out_name.to_string(), t)])
}

/// True when every dim is non-negative and their product is exactly `count`.
fn dims_hold(dims: &[i64], count: usize) -> bool {
    dims.iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(usize::try_from(d).ok()?))
        == Some(count)
}

/// Constant-fold ONNX `Reshape` with the target shape from `inputs[1]`.
///
/// - A `0` copies the matching input dim, unless `allowzero` is set; past the
///   input rank it becomes 1.
/// - A single `-1` is inferred from the element count.
///
/// Returns `None` when there is more than one `-1`, the `-1` cannot be
/// inferred, or the final dims do not match the element count. Previously a
/// literal `-1` could leak into the folded tensor's dims.
fn eval_reshape(
    node: &NodeProto,
    inputs: &[&TensorProto],
    out_name: &str,
) -> Option<Vec<(String, TensorProto)>> {
    if inputs.len() < 2 {
        return None;
    }
    let shape = tensor_to_i64(inputs[1]);
    if shape.is_empty() {
        return None;
    }
    let allowzero = node
        .attribute
        .iter()
        .find(|a| a.name == "allowzero")
        .map(|a| a.i != 0)
        .unwrap_or(false);
    let in_dims = &inputs[0].dims;
    emit_reshaped(inputs[0], out_name, |count| {
        let mut new_dims: Vec<i64> = shape
            .iter()
            .enumerate()
            .map(|(i, &d)| {
                if d == 0 && !allowzero {
                    *in_dims.get(i).unwrap_or(&1)
                } else {
                    d
                }
            })
            .collect();
        let inferred: Vec<usize> = new_dims
            .iter()
            .enumerate()
            .filter(|&(_, &d)| d == -1)
            .map(|(i, _)| i)
            .collect();
        match inferred.as_slice() {
            [] => {}
            [neg_idx] => {
                let known = new_dims
                    .iter()
                    .enumerate()
                    .filter(|&(i, _)| i != *neg_idx)
                    .try_fold(1i64, |acc, (_, &d)| acc.checked_mul(d))?;
                if known <= 0 {
                    return None;
                }
                new_dims[*neg_idx] = i64::try_from(count).ok()? / known;
            }
            _ => return None,
        }
        Some(new_dims)
    })
}

/// Constant-fold ONNX `Squeeze`.
///
/// Axes come from `inputs[1]` or else from the `axes` attribute. With no
/// axes, every size-1 dim is removed. Each listed axis (negative values count
/// from the end) must exist and have size 1, otherwise the fold is refused.
fn eval_squeeze(
    node: &NodeProto,
    inputs: &[&TensorProto],
    out_name: &str,
) -> Option<Vec<(String, TensorProto)>> {
    let input = inputs[0];
    let ndim = input.dims.len() as i64;
    let raw_axes: Vec<i64> = if inputs.len() >= 2 {
        tensor_to_i64(inputs[1])
    } else {
        node.attribute
            .iter()
            .find(|a| a.name == "axes")
            .map(|a| a.ints.clone())
            .unwrap_or_default()
    };
    let new_dims: Vec<i64> = if raw_axes.is_empty() {
        input.dims.iter().copied().filter(|&d| d != 1).collect()
    } else {
        let mut axes: Vec<usize> = Vec::with_capacity(raw_axes.len());
        for &a in &raw_axes {
            let ax = if a < 0 { a + ndim } else { a };
            if ax < 0 || ax >= ndim || input.dims[ax as usize] != 1 {
                return None;
            }
            axes.push(ax as usize);
        }
        input
            .dims
            .iter()
            .enumerate()
            .filter(|(i, _)| !axes.contains(i))
            .map(|(_, &d)| d)
            .collect()
    };
    emit_reshaped(input, out_name, |_| Some(new_dims))
}

/// Constant-fold ONNX `Unsqueeze`.
///
/// Axes come from `inputs[1]` or else from the `axes` attribute. They index
/// the output rank, and negative values count from its end. An out-of-range
/// or repeated axis refuses the fold; previously out-of-range axes were
/// silently skipped, which produced a wrong shape.
fn eval_unsqueeze(
    node: &NodeProto,
    inputs: &[&TensorProto],
    out_name: &str,
) -> Option<Vec<(String, TensorProto)>> {
    let axes: Vec<i64> = if inputs.len() >= 2 {
        tensor_to_i64(inputs[1])
    } else {
        node.attribute
            .iter()
            .find(|a| a.name == "axes")
            .map(|a| a.ints.clone())
            .unwrap_or_default()
    };
    let ndim = (inputs[0].dims.len() + axes.len()) as i64;
    let mut sorted_axes: Vec<usize> = Vec::with_capacity(axes.len());
    for &a in &axes {
        let ax = if a < 0 { a + ndim } else { a };
        if ax < 0 || ax >= ndim {
            return None;
        }
        sorted_axes.push(ax as usize);
    }
    sorted_axes.sort_unstable();
    if sorted_axes.windows(2).any(|w| w[0] == w[1]) {
        return None;
    }
    let mut new_dims = inputs[0].dims.clone();
    // Ascending insertion means each axis already refers to its final position.
    for &ax in &sorted_axes {
        if ax > new_dims.len() {
            return None;
        }
        new_dims.insert(ax, 1);
    }
    emit_reshaped(inputs[0], out_name, |_| Some(new_dims))
}

/// Constant-fold ONNX `Shape`: emit `input.dims[start..end]` as a 1-D INT64
/// tensor. Negative `start`/`end` count from the end, and both are clamped to
/// the rank. A rank-0 input is not folded.
fn eval_shape(
    node: &NodeProto,
    input: &TensorProto,
    out_name: &str,
) -> Option<Vec<(String, TensorProto)>> {
    let dims = &input.dims;
    if dims.is_empty() {
        return None;
    }
    let ndim = dims.len() as i64;
    let start_attr = node
        .attribute
        .iter()
        .find(|a| a.name == "start")
        .map(|a| a.i)
        .unwrap_or(0);
    let end_attr = node
        .attribute
        .iter()
        .find(|a| a.name == "end")
        .map(|a| a.i)
        .unwrap_or(ndim);
    let start = if start_attr < 0 {
        (ndim + start_attr).max(0) as usize
    } else {
        (start_attr as usize).min(dims.len())
    };
    let end = if end_attr < 0 {
        (ndim + end_attr).max(0) as usize
    } else {
        (end_attr as usize).min(dims.len())
    };
    let sliced: Vec<i64> = if start < end {
        dims[start..end].to_vec()
    } else {
        vec![]
    };
    let t = TensorProto {
        name: out_name.to_string(),
        data_type: TensorProto::INT64,
        dims: vec![sliced.len() as i64],
        int64_data: sliced,
        ..Default::default()
    };
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `Gather`. Only rank-1 data gathered along axis 0 (or
/// the equivalent -1) is handled.
///
/// - Output shape is the indices shape; a scalar index gives a scalar.
/// - INT64 data is read with `tensor_to_i64` so gathered shape values stay
///   exact.
/// - Other dtypes go through `tensor_to_f32`, falling back to the integer
///   decoder (and an INT64 result) if that yields nothing.
/// - An index outside `[-len, len)` refuses the fold. Previously it silently
///   produced 0.
fn eval_gather(
    node: &NodeProto,
    inputs: &[&TensorProto],
    out_name: &str,
) -> Option<Vec<(String, TensorProto)>> {
    let axis = node
        .attribute
        .iter()
        .find(|a| a.name == "axis")
        .map(|a| a.i)
        .unwrap_or(0);
    let data = inputs[0];
    let indices = tensor_to_i64(inputs[1]);
    if indices.is_empty() || data.dims.is_empty() {
        return None;
    }
    if data.dims.len() != 1 || !(axis == 0 || axis == -1) {
        return None;
    }
    let len = data.dims[0];
    let positions: Vec<usize> = indices
        .iter()
        .map(|&i| {
            let idx = if i < 0 { i + len } else { i };
            if idx < 0 || idx >= len {
                None
            } else {
                usize::try_from(idx).ok()
            }
        })
        .collect::<Option<Vec<usize>>>()?;
    let out_dims = inputs[1].dims.clone();
    let gather_i64 = |vals: &[i64]| -> Option<Vec<(String, TensorProto)>> {
        let result = positions
            .iter()
            .map(|&p| vals.get(p).copied())
            .collect::<Option<Vec<i64>>>()?;
        let t = TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims: out_dims.clone(),
            int64_data: result,
            ..Default::default()
        };
        Some(vec![(out_name.to_string(), t)])
    };
    if data.data_type == TensorProto::INT64 {
        let data_i64 = tensor_to_i64(data);
        if data_i64.is_empty() {
            return None;
        }
        return gather_i64(&data_i64);
    }
    let data_vals = tensor_to_f32(data);
    if data_vals.is_empty() {
        let data_i64 = tensor_to_i64(data);
        if data_i64.is_empty() {
            return None;
        }
        return gather_i64(&data_i64);
    }
    let result = positions
        .iter()
        .map(|&p| data_vals.get(p).copied())
        .collect::<Option<Vec<f32>>>()?;
    let t = make_f32_tensor(out_name, &out_dims, &result, data.data_type);
    Some(vec![(out_name.to_string(), t)])
}

/// Constant-fold ONNX `Slice` with starts/ends (and optional axes/steps)
/// supplied as constant inputs.
///
/// - Bounds are clamped per the ONNX rules for forward and reverse steps.
/// - `i64::MIN` as a reverse end means "include index 0".
/// - Zero-length axes always produce an empty axis.
///
/// Returns `None` on mismatched parameter lengths, a zero step, an
/// out-of-range axis, rank-0 data, undecodable data, or an oversized output.
fn eval_slice(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {
    let data = inputs[0];
    let starts = tensor_to_i64(inputs[1]);
    let ends = tensor_to_i64(inputs[2]);
    if starts.is_empty() || ends.is_empty() {
        return None;
    }
    let axes: Vec<i64> = if inputs.len() > 3 {
        tensor_to_i64(inputs[3])
    } else {
        (0..starts.len() as i64).collect()
    };
    let steps: Vec<i64> = if inputs.len() > 4 {
        tensor_to_i64(inputs[4])
    } else {
        vec![1; starts.len()]
    };
    if starts.len() != ends.len() || axes.len() != starts.len() || steps.len() != starts.len() {
        return None;
    }
    let rank = data.dims.len();
    if rank == 0 {
        return None;
    }
    let mut per_axis_range: Vec<(i64, i64, i64)> = (0..rank as i64)
        .map(|d| (0, data.dims[d as usize], 1))
        .collect();
    for (i, &raw_axis) in axes.iter().enumerate() {
        let a = if raw_axis < 0 {
            rank as i64 + raw_axis
        } else {
            raw_axis
        };
        if a < 0 || a >= rank as i64 {
            return None;
        }
        let dim = data.dims[a as usize];
        let step = steps[i];
        if step == 0 {
            return None;
        }
        if dim == 0 {
            // Zero-length axis: any slice yields an empty output on that
            // axis. Record (0, 0, step) and skip clamping to avoid the
            // clamp(..., 0, dim - 1) == clamp(..., 0, -1) inverted range.
            per_axis_range[a as usize] = (0, 0, step);
            continue;
        }
        let raw_start = starts[i];
        let raw_end = ends[i];
        let clamp = |v: i64, lo: i64, hi: i64| -> i64 { v.clamp(lo, hi) };
        let (s, e) = if step > 0 {
            // ONNX forward slice: start in [0, dim], end in [0, dim],
            // both treated as exclusive upper bound.
            let s = clamp(
                if raw_start < 0 {
                    dim + raw_start
                } else {
                    raw_start
                },
                0,
                dim,
            );
            let e = clamp(if raw_end < 0 { dim + raw_end } else { raw_end }, 0, dim);
            (s, e)
        } else {
            // ONNX reverse slice: start in [0, dim-1] (inclusive first
            // read), end in [-1, dim-1] (exclusive lower bound; -1
            // means "walk past index 0", i.e. include element 0).
            let s = clamp(
                if raw_start < 0 {
                    dim + raw_start
                } else {
                    raw_start
                },
                0,
                dim - 1,
            );
            let resolved_end = if raw_end == i64::MIN {
                -1
            } else if raw_end < 0 {
                dim + raw_end
            } else {
                raw_end
            };
            let e = clamp(resolved_end, -1, dim - 1);
            (s, e)
        };
        per_axis_range[a as usize] = (s, e, step);
    }
    let out_dims: Vec<i64> = per_axis_range
        .iter()
        .map(|(s, e, st)| {
            if *st > 0 {
                ((e - s + st - 1) / st).max(0)
            } else {
                ((s - e + (-st) - 1) / (-st)).max(0)
            }
        })
        .collect();
    let total = broadcast_total(&out_dims)?;
    if total == 0 {
        let t = TensorProto {
            name: out_name.to_string(),
            data_type: data.data_type,
            dims: out_dims,
            ..Default::default()
        };
        return Some(vec![(out_name.to_string(), t)]);
    }
    let in_strides: Vec<i64> = {
        let mut s = vec![1i64; rank];
        for i in (0..rank.saturating_sub(1)).rev() {
            s[i] = s[i + 1] * data.dims[i + 1];
        }
        s
    };
    let out_strides: Vec<i64> = {
        let mut s = vec![1i64; rank];
        for i in (0..rank.saturating_sub(1)).rev() {
            s[i] = s[i + 1] * out_dims[i + 1];
        }
        s
    };
    // Source flat index for output flat index `o`: per axis, start + coord * step.
    let src_index = |o: i64| -> i64 {
        let mut rem = o;
        let mut src = 0i64;
        for d in 0..rank {
            let coord = rem / out_strides[d];
            rem %= out_strides[d];
            let (s_axis, _, st) = per_axis_range[d];
            src += (s_axis + coord * st) * in_strides[d];
        }
        src
    };
    if data.data_type == TensorProto::INT64 {
        let vals = tensor_to_i64(data);
        if vals.is_empty() {
            return None;
        }
        let mut result = Vec::with_capacity(total);
        for o in 0..total {
            result.push(*vals.get(src_index(o as i64) as usize)?);
        }
        let t = TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims: out_dims,
            int64_data: result,
            ..Default::default()
        };
        return Some(vec![(out_name.to_string(), t)]);
    }
    let vals = tensor_to_f32(data);
    if vals.is_empty() {
        return None;
    }
    let mut result = Vec::with_capacity(total);
    for o in 0..total {
        result.push(*vals.get(src_index(o as i64) as usize)?);
    }
    let t = make_f32_tensor(out_name, &out_dims, &result, data.data_type);
    Some(vec![(out_name.to_string(), t)])
}

fn eval_scatter_nd(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {
    let data = inputs[0];
    let indices = inputs[1];
    let updates = inputs[2];
    let rank = data.dims.len();
    if rank == 0 || indices.dims.is_empty() {
        return None;
    }
    let q = *indices.dims.last()? as usize;
    if q == 0 || q > rank {
        return None;
    }
    let total = broadcast_total(&data.dims)?;
    let in_strides: Vec<i64> = {
        let mut s = vec![1i64; rank];
        for i in (0..rank.saturating_sub(1)).rev() {
            s[i] = s[i + 1] * data.dims[i + 1];
        }
        s
    };
    let trail_size: usize = data.dims[q..].iter().map(|&d| d as usize).product();
    let scatter_count: usize = indices.dims[..indices.dims.len() - 1]
        .iter()
        .map(|&d| d as usize)
        .product();
    let idx_vals = tensor_to_i64(indices);
    if idx_vals.len() != scatter_count * q {
        return None;
    }
    if data.data_type == TensorProto::INT64 {
        let mut buf = tensor_to_i64(data);
        if buf.len() != total {
            return None;
        }
        let upd_vals = tensor_to_i64(updates);
        if upd_vals.len() != scatter_count * trail_size {
            return None;
        }
        for s in 0..scatter_count {
            let mut base = 0i64;
            for d in 0..q {
                let mut idx = idx_vals[s * q + d];
                if idx < 0 {
                    idx += data.dims[d];
                }
                if idx < 0 || idx >= data.dims[d] {
                    return None;
                }
                base += idx * in_strides[d];
            }
            for k in 0..trail_size {
                buf[base as usize + k] = upd_vals[s * trail_size + k];
            }
        }
        let t = TensorProto {
            name: out_name.to_string(),
            data_type: TensorProto::INT64,
            dims: data.dims.clone(),
            int64_data: buf,
            ..Default::default()
        };
        return Some(vec![(out_name.to_string(), t)]);
    }
    let mut buf = tensor_to_f32(data);
    if buf.len() != total {
        return None;
    }
    let upd_vals = tensor_to_f32(updates);
    if upd_vals.len() != scatter_count * trail_size {
        return None;
    }
    for s in 0..scatter_count {
        let mut base
= 0i64; for d in 0..q { let mut idx = idx_vals[s * q + d]; if idx < 0 { idx += data.dims[d]; } if idx < 0 || idx >= data.dims[d] { return None; } base += idx * in_strides[d]; } for k in 0..trail_size { buf[base as usize + k] = upd_vals[s * trail_size + k]; } } let t = make_f32_tensor(out_name, &data.dims, &buf, data.data_type); Some(vec![(out_name.to_string(), t)]) } fn eval_split( node: &NodeProto, inputs: &[&TensorProto], output_names: &[String], ) -> Option> { let data = inputs.first()?; let rank = data.dims.len(); if rank == 0 { return None; } let raw_axis = node .attribute .iter() .find(|a| a.name == "axis") .map(|a| a.i) .unwrap_or(0); let axis = if raw_axis < 0 { rank as i64 + raw_axis } else { raw_axis } as usize; if axis >= rank { return None; } let split_sizes: Vec = if inputs.len() >= 2 { tensor_to_i64(inputs[1]) } else if let Some(attr) = node.attribute.iter().find(|a| a.name == "split") { attr.ints.clone() } else { let n = output_names.iter().filter(|s| !s.is_empty()).count() as i64; if n == 0 { return None; } let dim = data.dims[axis]; if dim % n != 0 { return None; } vec![dim / n; n as usize] }; if split_sizes.iter().sum::() != data.dims[axis] { return None; } let outputs: Vec<&str> = output_names .iter() .filter(|s| !s.is_empty()) .map(String::as_str) .collect(); if outputs.len() != split_sizes.len() { return None; } let prefix: usize = data.dims[..axis].iter().map(|&d| d as usize).product(); let suffix: usize = data.dims[axis + 1..].iter().map(|&d| d as usize).product(); let axis_in: usize = data.dims[axis] as usize; let mut result = Vec::with_capacity(outputs.len()); let is_int64 = data.data_type == TensorProto::INT64; let mut offset = 0usize; for (i, &sz) in split_sizes.iter().enumerate() { let sz_us = usize::try_from(sz).ok()?; if sz_us == 0 { return None; } let mut out_dims = data.dims.clone(); out_dims[axis] = sz; let total = prefix * sz_us * suffix; if is_int64 { let vals = tensor_to_i64(data); if vals.is_empty() { return None; } let mut 
chunk = Vec::with_capacity(total); for p in 0..prefix { for ai in 0..sz_us { let src_axis = offset + ai; let src_base = (p * axis_in + src_axis) * suffix; chunk.extend_from_slice(&vals[src_base..src_base + suffix]); } } let t = TensorProto { name: outputs[i].to_string(), data_type: TensorProto::INT64, dims: out_dims, int64_data: chunk, ..Default::default() }; result.push((outputs[i].to_string(), t)); } else { let vals = tensor_to_f32(data); if vals.is_empty() { return None; } let mut chunk = Vec::with_capacity(total); for p in 0..prefix { for ai in 0..sz_us { let src_axis = offset + ai; let src_base = (p * axis_in + src_axis) * suffix; chunk.extend_from_slice(&vals[src_base..src_base + suffix]); } } let t = make_f32_tensor(outputs[i], &out_dims, &chunk, data.data_type); result.push((outputs[i].to_string(), t)); } offset += sz_us; } Some(result) } fn eval_concat( node: &NodeProto, inputs: &[&TensorProto], out_name: &str, ) -> Option> { if inputs.is_empty() { return None; } let raw_axis = node .attribute .iter() .find(|a| a.name == "axis") .map(|a| a.i) .unwrap_or(0); let rank = inputs[0].dims.len(); if !inputs.iter().all(|t| t.dims.len() == rank) { return None; } if rank == 0 { return None; } let axis = if raw_axis < 0 { (rank as i64 + raw_axis) as usize } else { raw_axis as usize }; if axis >= rank { return None; } for d in 0..rank { if d == axis { continue; } let expected = inputs[0].dims[d]; if !inputs.iter().all(|t| t.dims[d] == expected) { return None; } } let mut out_dims = inputs[0].dims.clone(); out_dims[axis] = inputs.iter().map(|t| t.dims[axis]).sum(); let prefix_size: usize = out_dims[..axis].iter().map(|&d| d as usize).product(); let out_axis: usize = out_dims[axis] as usize; let suffix_size: usize = out_dims[axis + 1..].iter().map(|&d| d as usize).product(); let out_total = prefix_size .checked_mul(out_axis)? 
.checked_mul(suffix_size)?; if out_total > MAX_BROADCAST_ELEMENTS { return None; } // ONNX Concat requires homogeneous input element types, so the first // input's declared type is authoritative. let is_int64 = inputs[0].data_type == TensorProto::INT64; if is_int64 { let mut result: Vec = vec![0; out_total]; let mut axis_offset: usize = 0; for t in inputs { let t_vals = tensor_to_i64(t); let t_axis = t.dims[axis] as usize; if t_axis > 0 && t_vals.is_empty() { return None; } for p in 0..prefix_size { for ai in 0..t_axis { for s in 0..suffix_size { let src = (p * t_axis + ai) * suffix_size + s; let dst = (p * out_axis + axis_offset + ai) * suffix_size + s; result[dst] = t_vals[src]; } } } axis_offset += t_axis; } let t = TensorProto { name: out_name.to_string(), data_type: TensorProto::INT64, dims: out_dims, int64_data: result, ..Default::default() }; return Some(vec![(out_name.to_string(), t)]); } let mut result: Vec = vec![0.0; out_total]; let mut axis_offset: usize = 0; for t in inputs { let t_vals = tensor_to_f32(t); let t_axis = t.dims[axis] as usize; if t_axis > 0 && t_vals.is_empty() { return None; } for p in 0..prefix_size { for ai in 0..t_axis { for s in 0..suffix_size { let src = (p * t_axis + ai) * suffix_size + s; let dst = (p * out_axis + axis_offset + ai) * suffix_size + s; result[dst] = t_vals[src]; } } } axis_offset += t_axis; } let t = make_f32_tensor(out_name, &out_dims, &result, inputs[0].data_type); Some(vec![(out_name.to_string(), t)]) } fn make_f32_tensor(name: &str, dims: &[i64], vals: &[f32], target_type: i32) -> TensorProto { match target_type { TensorProto::INT64 => TensorProto { name: name.to_string(), data_type: TensorProto::INT64, dims: dims.to_vec(), int64_data: vals.iter().map(|&v| v as i64).collect(), ..Default::default() }, TensorProto::INT32 => TensorProto { name: name.to_string(), data_type: TensorProto::INT32, dims: dims.to_vec(), int32_data: vals.iter().map(|&v| v as i32).collect(), ..Default::default() }, TensorProto::DOUBLE => 
TensorProto { name: name.to_string(), data_type: TensorProto::DOUBLE, dims: dims.to_vec(), double_data: vals.iter().map(|&v| v as f64).collect(), ..Default::default() }, TensorProto::BOOL => TensorProto { name: name.to_string(), data_type: TensorProto::BOOL, dims: dims.to_vec(), int32_data: vals.iter().map(|&v| (v != 0.0) as i32).collect(), ..Default::default() }, _ => TensorProto { name: name.to_string(), data_type: TensorProto::FLOAT, dims: dims.to_vec(), float_data: vals.to_vec(), ..Default::default() }, } } struct ConvBnFusion { conv_idx: usize, bn_idx: usize, bn_output: String, w_name: String, bias_name: String, has_bias: bool, orig_bias: Vec, gamma: Vec, beta: Vec, mean: Vec, var: Vec, eps: f32, // Initialiser names that become dead after the fusion: the BN's // four parameter inputs (gamma / beta / running mean / running // variance) and, if the Conv had no bias before fusion, the // auto-named "_fused_bias" we create. Collected here so the // caller can purge them in a single post-pass sweep without // re-walking every BN node. 
stale_bn_param_names: Vec, } pub fn fuse_conv_batchnorm(graph: &mut GraphProto) -> usize { let fusions = { let init_map: HashMap<&str, &TensorProto> = graph .initializer .iter() .map(|t| (t.name.as_str(), t)) .collect(); let node_output_map: HashMap<&str, usize> = graph .node .iter() .enumerate() .flat_map(|(i, n)| n.output.iter().map(move |o| (o.as_str(), i))) .collect(); let mut fusions: Vec = Vec::new(); for (bn_idx, bn_node) in graph.node.iter().enumerate() { if bn_node.op_type != "BatchNormalization" || bn_node.input.len() < 5 { continue; } let bn_input = &bn_node.input[0]; let conv_idx = match node_output_map.get(bn_input.as_str()) { Some(&idx) => idx, None => continue, }; let conv_node = &graph.node[conv_idx]; if conv_node.op_type != "Conv" || conv_node.output.is_empty() { continue; } let consumers: usize = graph .node .iter() .filter(|n| n.input.contains(&conv_node.output[0])) .count(); if consumers != 1 { continue; } let gamma = match init_map.get(bn_node.input[1].as_str()) { Some(t) => tensor_to_f32(t), None => continue, }; let beta = match init_map.get(bn_node.input[2].as_str()) { Some(t) => tensor_to_f32(t), None => continue, }; let mean = match init_map.get(bn_node.input[3].as_str()) { Some(t) => tensor_to_f32(t), None => continue, }; let var = match init_map.get(bn_node.input[4].as_str()) { Some(t) => tensor_to_f32(t), None => continue, }; if gamma.is_empty() || gamma.len() != beta.len() || gamma.len() != mean.len() || gamma.len() != var.len() { continue; } let bn_output = match bn_node.output.first() { Some(o) if !o.is_empty() => o.clone(), _ => continue, }; let eps = bn_node .attribute .iter() .find(|a| a.name == "epsilon") .map(|a| a.f) .unwrap_or(1e-5); let w_name = conv_node.input[1].clone(); let has_bias = conv_node.input.len() > 2; let bias_name = if has_bias { conv_node.input[2].clone() } else { format!("{}_fused_bias", w_name) }; let orig_bias = if has_bias { init_map .get(conv_node.input[2].as_str()) .map(|t| tensor_to_f32(t)) 
.unwrap_or_default() } else { vec![] }; let stale_bn_param_names = vec![ bn_node.input[1].clone(), bn_node.input[2].clone(), bn_node.input[3].clone(), bn_node.input[4].clone(), ]; fusions.push(ConvBnFusion { conv_idx, bn_idx, bn_output, w_name, bias_name, has_bias, orig_bias, gamma, beta, mean, var, eps, stale_bn_param_names, }); } fusions }; if fusions.is_empty() { return 0; } let mut removed_bn: HashSet = HashSet::new(); let mut stale_init_names: HashSet = HashSet::new(); for f in &fusions { let channels = f.gamma.len(); let scale: Vec = (0..channels) .map(|c| f.gamma[c] / (f.var[c] + f.eps).sqrt()) .collect(); let w_ok = if let Some(w_init) = graph.initializer.iter_mut().find(|i| i.name == f.w_name) { let mut w_data = tensor_to_f32(w_init); // tensor_to_f32 returns empty for unsupported dtypes // (e.g. f16 / bf16 weights we don't yet convert); skip // the fusion for this Conv rather than silently clearing // the initializer into a zero-length FLOAT tensor that // would fail every downstream shape check. if w_data.is_empty() { false } else if !w_init.dims.is_empty() && w_init.dims[0] as usize == channels { let per_filter = w_data.len() / channels; for c in 0..channels { for j in 0..per_filter { w_data[c * per_filter + j] *= scale[c]; } } w_init.float_data = w_data; w_init.raw_data.clear(); // The initialiser may have arrived as half / bfloat // encoded in raw_data; float_data is FLOAT by // definition, so stamp the tensor metadata to match // the new representation. 
w_init.data_type = TensorProto::FLOAT; true } else { false } } else { false }; if !w_ok { continue; } let fused_bias: Vec = (0..channels) .map(|c| { let ob = f.orig_bias.get(c).copied().unwrap_or(0.0); (ob - f.mean[c]) * scale[c] + f.beta[c] }) .collect(); if let Some(b_init) = graph.initializer.iter_mut().find(|i| i.name == f.bias_name) { b_init.float_data = fused_bias; b_init.raw_data.clear(); b_init.dims = vec![channels as i64]; b_init.data_type = TensorProto::FLOAT; } else { graph.initializer.push(TensorProto { name: f.bias_name.clone(), data_type: TensorProto::FLOAT, dims: vec![channels as i64], float_data: fused_bias, ..Default::default() }); } let conv_node = &mut graph.node[f.conv_idx]; if !f.has_bias { conv_node.input.push(f.bias_name.clone()); } conv_node.output[0] = f.bn_output.clone(); removed_bn.insert(f.bn_idx); stale_init_names.extend(f.stale_bn_param_names.iter().cloned()); } if !removed_bn.is_empty() { let mut idx = 0; graph.node.retain(|_| { let keep = !removed_bn.contains(&idx); idx += 1; keep }); } if !stale_init_names.is_empty() { // Only drop BN parameter initialisers that no surviving node // still references. Rare in practice but cheap to verify // and prevents accidentally deleting an initialiser shared // between a fused Conv+BN and an unrelated node elsewhere. 
let still_used: HashSet<&str> = graph .node .iter() .flat_map(|n| n.input.iter().map(String::as_str)) .collect(); graph.initializer.retain(|init| { !stale_init_names.contains(&init.name) || still_used.contains(init.name.as_str()) }); } removed_bn.len() } ================================================ FILE: crates/dsperse/src/slicer/onnx_proto.rs ================================================ #[allow(clippy::doc_overindented_list_items)] pub mod onnx { include!(concat!(env!("OUT_DIR"), "/onnx.rs")); } use std::collections::{HashMap, HashSet}; use std::path::Path; use prost::Message; use crate::error::{DsperseError, Result}; pub use onnx::{ AttributeProto, GraphProto, ModelProto, NodeProto, OperatorSetIdProto, TensorProto, TypeProto, ValueInfoProto, }; pub use super::onnx_shapes::{ elem_type_from_value_info, resolve_dynamic_input_shapes, set_vi_shape, shape_from_value_info, strip_symbolic_value_info, vi_shape, }; pub fn load_model(path: &Path) -> Result { let bytes = crate::utils::limits::read_checked(path)?; ModelProto::decode(bytes.as_slice()) .map_err(|e| DsperseError::Slicer(format!("decode {}: {e}", path.display()))) } fn canonicalize_node_attributes(nodes: &mut [NodeProto]) { for node in nodes { node.attribute.sort_by(|a, b| a.name.cmp(&b.name)); for attr in &mut node.attribute { if let Some(g) = attr.g.as_mut() { canonicalize_node_attributes(&mut g.node); } for g in &mut attr.graphs { canonicalize_node_attributes(&mut g.node); } } } } pub fn save_model(model: &ModelProto, path: &Path) -> Result<()> { let mut model = model.clone(); if let Some(graph) = model.graph.as_mut() { canonicalize_node_attributes(&mut graph.node); } for func in &mut model.functions { canonicalize_node_attributes(&mut func.node); } if let Some(parent) = path.parent() { std::fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?; } let bytes = model.encode_to_vec(); std::fs::write(path, bytes).map_err(|e| DsperseError::io(e, path)) } pub fn make_tensor_value_info(name: 
&str, elem_type: i32, shape: &[i64]) -> ValueInfoProto { ValueInfoProto { name: name.to_string(), r#type: Some(TypeProto { denotation: String::new(), value: Some(onnx::type_proto::Value::TensorType( onnx::type_proto::Tensor { elem_type, shape: Some(onnx::TensorShapeProto { dim: shape .iter() .map(|&d| onnx::tensor_shape_proto::Dimension { denotation: String::new(), value: Some(onnx::tensor_shape_proto::dimension::Value::DimValue( d, )), }) .collect(), }), }, )), }), doc_string: String::new(), metadata_props: vec![], } } pub fn make_tensor(name: &str, elem_type: i32, dims: &[i64], float_data: Vec) -> TensorProto { TensorProto { name: name.to_string(), data_type: elem_type, dims: dims.to_vec(), float_data, ..Default::default() } } pub fn make_node( op_type: &str, inputs: Vec, outputs: Vec, attributes: Vec, ) -> NodeProto { NodeProto { op_type: op_type.to_string(), input: inputs, output: outputs, attribute: attributes, name: String::new(), domain: String::new(), doc_string: String::new(), overload: String::new(), metadata_props: vec![], device_configurations: vec![], } } pub fn make_graph( name: &str, nodes: Vec, inputs: Vec, outputs: Vec, initializers: Vec, ) -> GraphProto { GraphProto { name: name.to_string(), node: nodes, input: inputs, output: outputs, initializer: initializers, ..Default::default() } } pub fn make_model(graph: GraphProto, opset_version: i64) -> ModelProto { ModelProto { ir_version: 8, graph: Some(graph), opset_import: vec![OperatorSetIdProto { domain: String::new(), version: opset_version, }], ..Default::default() } } pub fn make_attribute_ints(name: &str, ints: &[i64]) -> AttributeProto { AttributeProto { name: name.to_string(), r#type: onnx::attribute_proto::AttributeType::Ints as i32, ints: ints.to_vec(), ..Default::default() } } pub fn make_attribute_int(name: &str, val: i64) -> AttributeProto { AttributeProto { name: name.to_string(), r#type: onnx::attribute_proto::AttributeType::Int as i32, i: val, ..Default::default() } } pub fn 
get_attribute_ints(node: &NodeProto, name: &str) -> Option> { node.attribute .iter() .find(|a| a.name == name) .map(|a| a.ints.clone()) } pub fn get_attribute_int(node: &NodeProto, name: &str) -> Option { node.attribute.iter().find(|a| a.name == name).map(|a| a.i) } pub fn get_attribute_float(node: &NodeProto, name: &str) -> Option { node.attribute.iter().find(|a| a.name == name).map(|a| a.f) } pub fn make_attribute_float(name: &str, val: f32) -> AttributeProto { AttributeProto { name: name.to_string(), f: val, r#type: 1, ..Default::default() } } pub fn tensor_to_i64(tensor: &TensorProto) -> Vec { if !tensor.int64_data.is_empty() { return tensor.int64_data.clone(); } if !tensor.raw_data.is_empty() && tensor.data_type == TensorProto::INT64 { if !tensor.raw_data.len().is_multiple_of(8) { tracing::warn!( tensor = %tensor.name, raw_len = tensor.raw_data.len(), "misaligned INT64 raw_data, skipping" ); return Vec::new(); } return tensor .raw_data .chunks_exact(8) .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])) .collect(); } if !tensor.int32_data.is_empty() { return tensor.int32_data.iter().map(|&v| v as i64).collect(); } Vec::new() } pub fn tensor_to_f32(tensor: &TensorProto) -> Vec { if !tensor.float_data.is_empty() { return tensor.float_data.clone(); } if !tensor.raw_data.is_empty() && tensor.data_type == TensorProto::FLOAT { let chunks = tensor.raw_data.chunks_exact(4); if !chunks.remainder().is_empty() { return Vec::new(); } return chunks .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) .collect(); } if !tensor.int64_data.is_empty() { return tensor.int64_data.iter().map(|&v| v as f32).collect(); } if !tensor.int32_data.is_empty() { return tensor.int32_data.iter().map(|&v| v as f32).collect(); } if !tensor.double_data.is_empty() { return tensor.double_data.iter().map(|&v| v as f32).collect(); } if !tensor.raw_data.is_empty() { match tensor.data_type { TensorProto::INT64 => { let chunks = tensor.raw_data.chunks_exact(8); if 
!chunks.remainder().is_empty() { return Vec::new(); } return chunks .map(|c| { i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32 }) .collect(); } TensorProto::INT32 => { let chunks = tensor.raw_data.chunks_exact(4); if !chunks.remainder().is_empty() { return Vec::new(); } return chunks .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32) .collect(); } TensorProto::DOUBLE => { let chunks = tensor.raw_data.chunks_exact(8); if !chunks.remainder().is_empty() { return Vec::new(); } return chunks .map(|c| { f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32 }) .collect(); } _ => {} } } Vec::new() } /// Decode a TensorProto into `Vec` directly, without going /// through `f32`. FLOAT / DOUBLE / INT32 payloads — in either /// the typed `*_data` fields or the little-endian `raw_data` /// byte stream — round-trip exactly: DOUBLE keeps its full 52- /// bit mantissa, FLOAT widens losslessly, and INT32 is always /// within f64's exact-integer range. /// /// INT64 is a partial exception. f64 exactly represents every /// integer in `[-2^53, 2^53]`; INT64 magnitudes beyond 2^53 are /// rounded to the nearest representable f64 and are not /// preserved bit-for-bit. This still beats the previous /// `tensor_to_f32 -> f64::from(f32)` chain (which truncated at /// 2^24) but callers that need full INT64 fidelity must not use /// this decoder. /// /// Returns an empty `Vec` on unsupported / unrecognised dtypes /// or malformed `raw_data` length so callers can use the /// existing `data.is_empty()` skip path. 
pub fn tensor_to_f64(tensor: &TensorProto) -> Vec { if !tensor.double_data.is_empty() { return tensor.double_data.clone(); } if !tensor.float_data.is_empty() { return tensor.float_data.iter().map(|&v| f64::from(v)).collect(); } if !tensor.int64_data.is_empty() { #[allow(clippy::cast_precision_loss)] return tensor.int64_data.iter().map(|&v| v as f64).collect(); } if !tensor.int32_data.is_empty() { return tensor.int32_data.iter().map(|&v| f64::from(v)).collect(); } if tensor.raw_data.is_empty() { return Vec::new(); } match tensor.data_type { TensorProto::DOUBLE => { let chunks = tensor.raw_data.chunks_exact(8); if !chunks.remainder().is_empty() { return Vec::new(); } chunks .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])) .collect() } TensorProto::FLOAT => { let chunks = tensor.raw_data.chunks_exact(4); if !chunks.remainder().is_empty() { return Vec::new(); } chunks .map(|c| f64::from(f32::from_le_bytes([c[0], c[1], c[2], c[3]]))) .collect() } TensorProto::INT64 => { let chunks = tensor.raw_data.chunks_exact(8); if !chunks.remainder().is_empty() { return Vec::new(); } #[allow(clippy::cast_precision_loss)] chunks .map(|c| { i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f64 }) .collect() } TensorProto::INT32 => { let chunks = tensor.raw_data.chunks_exact(4); if !chunks.remainder().is_empty() { return Vec::new(); } chunks .map(|c| f64::from(i32::from_le_bytes([c[0], c[1], c[2], c[3]]))) .collect() } _ => Vec::new(), } } pub fn build_initializer_map(graph: &GraphProto) -> HashMap { graph .initializer .iter() .map(|i| (i.name.clone(), i)) .collect() } pub fn build_value_info_map(graph: &GraphProto) -> HashMap { let mut map: HashMap = HashMap::new(); for vi in &graph.input { map.insert(vi.name.clone(), vi); } for vi in &graph.output { map.insert(vi.name.clone(), vi); } for vi in &graph.value_info { map.insert(vi.name.clone(), vi); } map } impl TensorProto { pub const FLOAT: i32 = 1; pub const INT64: i32 = 7; pub const 
DOUBLE: i32 = 11; pub const INT32: i32 = 6; pub const FLOAT16: i32 = 10; pub const BOOL: i32 = 9; } fn is_paddable_shape(target: &[i64], donor: &[i64]) -> bool { if target.len() != donor.len() || target.is_empty() { return false; } let last = target.len() - 1; target[..last] == donor[..last] && donor[last] < target[last] && donor[last] > 0 } pub fn validate_initializer_compatibility( initializers: &[TensorProto], donor_init_map: &HashMap, context: &str, ) -> Result<()> { for init in initializers { if let Some(donor) = donor_init_map.get(&init.name) { if init.data_type != donor.data_type { return Err(DsperseError::Pipeline(format!( "dtype mismatch for initializer '{}' in {context}: slice has dtype {}, consumer has dtype {}", init.name, init.data_type, donor.data_type ))); } if init.dims != donor.dims { if is_paddable_shape(&init.dims, &donor.dims) { tracing::info!( name = %init.name, target = ?init.dims, donor = ?donor.dims, "donor initializer will be zero-padded on last axis" ); } else { return Err(DsperseError::Pipeline(format!( "shape mismatch for initializer '{}' in {context}: slice expects {:?}, consumer provides {:?}", init.name, init.dims, donor.dims ))); } } } else { tracing::debug!( name = %init.name, context, "initializer not in donor weights, retaining slice value" ); } } Ok(()) } fn pad_float_data( donor_data: &[f32], target_dims: &[i64], donor_dims: &[i64], pad_val: f32, ) -> Vec { let last = target_dims.len() - 1; let target_last = target_dims[last] as usize; let donor_last = donor_dims[last] as usize; let rows = donor_data.len() / donor_last.max(1); let mut padded = Vec::with_capacity(rows * target_last); for row in 0..rows { let start = row * donor_last; let end = start + donor_last; padded.extend_from_slice(&donor_data[start..end.min(donor_data.len())]); padded.resize(padded.len() + (target_last - donor_last), pad_val); } padded } fn pad_raw_data_f32(raw: &[u8], target_dims: &[i64], donor_dims: &[i64], pad_val: f32) -> Vec { let donor_floats: Vec = 
raw .chunks_exact(4) .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) .collect(); let padded = pad_float_data(&donor_floats, target_dims, donor_dims, pad_val); padded.iter().flat_map(|f| f.to_le_bytes()).collect() } pub fn replace_initializers( model: &mut ModelProto, donor_init_map: &HashMap, ) -> Result { let graph = model .graph .as_mut() .ok_or_else(|| DsperseError::Pipeline("ONNX model missing graph".into()))?; let mut replaced = 0; for init in &mut graph.initializer { if let Some(donor) = donor_init_map.get(&init.name) { if init.data_type != donor.data_type { return Err(DsperseError::Pipeline(format!( "dtype mismatch for initializer '{}' in replace_initializers: slice has dtype {}, consumer has dtype {}", init.name, init.data_type, donor.data_type ))); } let needs_pad = init.dims != donor.dims && is_paddable_shape(&init.dims, &donor.dims); if init.dims != donor.dims && !needs_pad { return Err(DsperseError::Pipeline(format!( "shape mismatch for initializer '{}' in replace_initializers: slice expects {:?}, consumer provides {:?}", init.name, init.dims, donor.dims ))); } if needs_pad { let is_bias = donor.dims.len() == 1; let pad_val: f32 = if is_bias { -10.0 } else { 0.0 }; if !donor.float_data.is_empty() { init.float_data = pad_float_data(&donor.float_data, &init.dims, &donor.dims, pad_val); init.raw_data.clear(); } else if !donor.raw_data.is_empty() && donor.data_type == TensorProto::FLOAT { init.raw_data = pad_raw_data_f32(&donor.raw_data, &init.dims, &donor.dims, pad_val); init.float_data.clear(); } tracing::info!( name = %init.name, from = ?donor.dims, to = ?init.dims, "padded donor initializer" ); } else { init.float_data = donor.float_data.clone(); init.raw_data = donor.raw_data.clone(); init.double_data = donor.double_data.clone(); init.int32_data = donor.int32_data.clone(); init.int64_data = donor.int64_data.clone(); } replaced += 1; } } Ok(replaced) } pub fn build_patched_onnx( slice_onnx: &Path, donor_init_map: &HashMap, ) -> Result { let mut 
model = load_model(slice_onnx)?; replace_initializers(&mut model, donor_init_map)?; let tmp = tempfile::NamedTempFile::with_suffix(".onnx") .map_err(|e| DsperseError::Pipeline(format!("create temp file: {e}")))?; save_model(&model, tmp.path())?; Ok(tmp) } fn model_opset_version(model: &ModelProto) -> i64 { model .opset_import .iter() .find(|o| o.domain.is_empty() || o.domain == "ai.onnx") .map(|o| o.version) .unwrap_or(1) } fn min_opset_for_op(op_type: &str) -> Option { match op_type { "GridSample" => Some(16), "ScatterND" => Some(16), "ScatterElements" => Some(16), _ => None, } } pub fn normalize_opset(model: &mut ModelProto) -> usize { let opset = model_opset_version(model); if opset < 13 { return 0; } let graph = match model.graph.as_mut() { Some(g) => g, None => return 0, }; let mut required_opset = opset; for node in graph.node.iter() { if let Some(min) = min_opset_for_op(&node.op_type) { required_opset = required_opset.max(min); } } let mut new_initializers: Vec = Vec::new(); let mut count = 0; for node in &mut graph.node { match node.op_type.as_str() { "Unsqueeze" | "Squeeze" if node.input.len() == 1 => { if let Some(axes) = get_attribute_ints(node, "axes") { let axes_name = format!("{}_axes_const", node.name); new_initializers.push(TensorProto { name: axes_name.clone(), data_type: TensorProto::INT64, dims: vec![axes.len() as i64], int64_data: axes, ..Default::default() }); node.input.push(axes_name); node.attribute.retain(|a| a.name != "axes"); count += 1; } } "Reshape" if opset < 14 => { let had = node.attribute.iter().any(|a| a.name == "allowzero"); if had { node.attribute.retain(|a| a.name != "allowzero"); count += 1; } } _ => {} } } graph.initializer.extend(new_initializers); if required_opset > opset { if let Some(entry) = model .opset_import .iter_mut() .find(|o| o.domain.is_empty() || o.domain == "ai.onnx") { entry.version = required_opset; } tracing::info!( from = opset, to = required_opset, "bumped declared opset to match op requirements" ); count 
+= 1;
    }
    if count > 0 {
        tracing::info!(
            opset = required_opset,
            fixes = count,
            "normalized ONNX opset conventions"
        );
    }
    count
}

/// Prepare the graph of `model` for the circuit backend: propagate
/// constants, resolve zero-valued placeholder dims, flatten MatMuls whose
/// left operand has rank > 3, and materialize missing Reshape shape inputs.
///
/// Returns the total number of rewrites (folded constants plus fixes), or
/// 0 when the model has no graph.
pub fn normalize_for_circuit_backend(model: &mut ModelProto) -> usize {
    let graph = match model.graph.as_mut() {
        Some(g) => g,
        None => return 0,
    };
    let folded_names = super::onnx_fold::propagate_constants(graph);
    let folded = folded_names.len();
    let fixed = fix_zero_dims(graph);
    let count = flatten_matmul_inputs(graph) + materialize_reshape_targets(graph) + fixed;
    let total = folded + count;
    if total > 0 {
        tracing::info!(
            total,
            folded,
            "normalized graph for circuit backend compatibility"
        );
    }
    total
}

/// Replace `value_info` / graph-output shapes that contain a zero-valued
/// placeholder dim with a fully positive shape known for the same tensor
/// name (from a graph input, an initializer, or a `value_info` entry).
/// Returns the number of entries rewritten.
fn fix_zero_dims(graph: &mut GraphProto) -> usize {
    let mut shapes: HashMap> = HashMap::new();
    for inp in &graph.input {
        if let Some(s) = shape_from_value_info(inp)
            && s.iter().all(|&d| d > 0)
        {
            shapes.insert(inp.name.clone(), s);
        }
    }
    for init in &graph.initializer {
        if !init.dims.is_empty() {
            shapes.insert(init.name.clone(), init.dims.clone());
        }
    }
    for vi in &graph.value_info {
        if let Some(s) = shape_from_value_info(vi)
            && s.iter().all(|&d| d > 0)
            && !shapes.contains_key(&vi.name)
        {
            shapes.insert(vi.name.clone(), s);
        }
    }
    let mut count = 0;
    for vi in graph.value_info.iter_mut().chain(graph.output.iter_mut()) {
        if let Some(new_shape) = shapes.get(&vi.name)
            && let Some(existing) = shape_from_value_info(vi)
            && existing.contains(&0)
        {
            set_vi_shape(vi, new_shape);
            count += 1;
        }
    }
    if count > 0 {
        tracing::info!(count, "resolved zero-valued placeholder dimensions");
    }
    count
}

/// Rewrite each `MatMul` whose left operand `A` has rank > 3 into
/// `Reshape(A -> [batch * m, k]) -> MatMul -> Reshape(restore)`. A rank > 2
/// `B` whose batch volume is 1 is likewise reshaped to `[k, n]`. The effect
/// is that the circuit backend only sees MatMuls with a rank-2 left operand.
///
/// A `MatMul` is left untouched when any of these hold:
/// - `A` or `B` has no recorded shape, an empty shape, or a non-positive
///   (placeholder / dynamic) dim, because the flattened volume would be wrong;
/// - `B` has rank > 2 with a batch volume other than 1. Pairing a
///   flattened `A` with a batched `B` would broadcast instead of batching,
///   and the restore `Reshape` would see the wrong element count;
/// - `B` has a higher rank than `A`. The real output carries `B`'s extra
///   leading dims, which a restore shape built from `A`'s batch dims
///   would drop.
///
/// Inserted nodes are named after the MatMul, or `matmul_{idx}` when it
/// is unnamed, so several unnamed MatMuls don't produce clashing names.
/// Returns the number of `MatMul` nodes rewritten.
fn flatten_matmul_inputs(graph: &mut GraphProto) -> usize {
    /// True when the shape is non-empty and every dim is a known positive extent.
    fn is_concrete(shape: &[i64]) -> bool {
        !shape.is_empty() && shape.iter().all(|&d| d > 0)
    }

    let vi_shapes: HashMap> = graph
        .input
        .iter()
        .chain(graph.value_info.iter())
        .chain(graph.output.iter())
        .filter_map(|vi| shape_from_value_info(vi).map(|s| (vi.name.clone(), s)))
        .collect();
    let init_shapes: HashMap> = graph
        .initializer
        .iter()
        .map(|i| (i.name.clone(), i.dims.clone()))
        .collect();
    // Initializer dims override value_info entries of the same name.
    let shapes: HashMap> = vi_shapes.into_iter().chain(init_shapes).collect();
    let elem_types: HashMap = graph
        .input
        .iter()
        .chain(graph.value_info.iter())
        .chain(graph.output.iter())
        .filter_map(|vi| elem_type_from_value_info(vi).map(|t| (vi.name.clone(), t)))
        .chain(
            graph
                .initializer
                .iter()
                .map(|i| (i.name.clone(), i.data_type)),
        )
        .collect();
    let mut new_nodes = Vec::new();
    let mut new_inits = Vec::new();
    let mut new_vis = Vec::new();
    let mut count = 0;
    for (idx, node) in graph.node.iter().enumerate() {
        if node.op_type != "MatMul" {
            continue;
        }
        let a_name = match node.input.first() {
            Some(n) if !n.is_empty() => n,
            _ => continue,
        };
        let b_name = match node.input.get(1) {
            Some(n) if !n.is_empty() => n,
            _ => continue,
        };
        let a_shape = match shapes.get(a_name) {
            Some(s) if s.len() > 3 && is_concrete(s) => s.clone(),
            _ => continue,
        };
        let b_shape = match shapes.get(b_name) {
            Some(s) if is_concrete(s) && s.len() <= a_shape.len() => s.clone(),
            _ => continue,
        };
        if b_shape.len() > 2 {
            // A batched B cannot be paired with a flattened A: the 2-D
            // MatMul would broadcast B's batch instead of matching it.
            let b_batch: i64 = b_shape[..b_shape.len() - 2].iter().product();
            if b_batch != 1 {
                continue;
            }
        }
        let out_name = match node.output.first() {
            Some(n) if !n.is_empty() => n.clone(),
            _ => continue,
        };
        let batch_dims = &a_shape[..a_shape.len() - 2];
        let batch_vol: i64 = batch_dims.iter().product();
        let m = a_shape[a_shape.len() - 2];
        let k = a_shape[a_shape.len() - 1];
        let node_tag = if node.name.is_empty() {
            format!("matmul_{idx}")
        } else {
            node.name.clone()
        };
        let a_2d_name = format!("{a_name}__flat2d_{node_tag}");
        let a_2d_shape_name = format!("{a_name}__flat2d_shape_{node_tag}");
        let a_2d = vec![batch_vol * m, k];

        // `None` when B is a 1-D vector: MatMul then drops the trailing axis.
        let n_dim = if b_shape.len() > 1 {
            b_shape.last().copied()
        } else {
            None
        };

        // For a rank > 2 B (unit batch volume, checked above), reshape it
        // to [b_m, n] and remember (flattened name, shape-initializer name).
        let b_flat = if b_shape.len() > 2 {
            let b_m = b_shape[b_shape.len() - 2];
            let b_n = b_shape[b_shape.len() - 1];
            let b_2d_name = format!("{b_name}__flat2d_{node_tag}");
            let b_2d_shape_name = format!("{b_name}__flat2d_shape_{node_tag}");
            let b_2d = vec![b_m, b_n];
            new_inits.push(TensorProto {
                name: b_2d_shape_name.clone(),
                data_type: TensorProto::INT64,
                dims: vec![2],
                int64_data: b_2d.clone(),
                ..Default::default()
            });
            let b_elem = elem_types
                .get(b_name)
                .copied()
                .unwrap_or(TensorProto::FLOAT);
            new_vis.push(make_tensor_value_info(&b_2d_name, b_elem, &b_2d));
            Some((b_2d_name, b_2d_shape_name))
        } else {
            None
        };

        let matmul_out_name = format!("{out_name}__matmul2d_{node_tag}");
        let restore_shape_name = format!("{out_name}__restore_shape_{node_tag}");
        let mut matmul_2d_shape = vec![batch_vol * m];
        let mut restored = batch_dims.to_vec();
        restored.push(m);
        if let Some(n) = n_dim {
            matmul_2d_shape.push(n);
            restored.push(n);
        }
        new_inits.push(TensorProto {
            name: a_2d_shape_name.clone(),
            data_type: TensorProto::INT64,
            dims: vec![2],
            int64_data: a_2d.clone(),
            ..Default::default()
        });
        new_inits.push(TensorProto {
            name: restore_shape_name.clone(),
            data_type: TensorProto::INT64,
            dims: vec![restored.len() as i64],
            int64_data: restored,
            ..Default::default()
        });
        let a_elem = elem_types
            .get(a_name)
            .copied()
            .unwrap_or(TensorProto::FLOAT);
        new_vis.push(make_tensor_value_info(&a_2d_name, a_elem, &a_2d));
        new_vis.push(make_tensor_value_info(
            &matmul_out_name,
            a_elem,
            &matmul_2d_shape,
        ));

        let mut inserted = vec![NodeProto {
            op_type: "Reshape".into(),
            name: format!("{node_tag}_flatten_a"),
            input: vec![a_name.clone(), a_2d_shape_name],
            output: vec![a_2d_name.clone()],
            ..Default::default()
        }];
        let b_input = match b_flat {
            Some((b_2d_name, b_2d_shape_name)) => {
                inserted.push(NodeProto {
                    op_type: "Reshape".into(),
                    name: format!("{node_tag}_flatten_b"),
                    input: vec![b_name.clone(), b_2d_shape_name],
                    output: vec![b_2d_name.clone()],
                    ..Default::default()
                });
                b_2d_name
            }
            None => b_name.clone(),
        };
        inserted.push(NodeProto {
            op_type: "MatMul".into(),
            name: node.name.clone(),
            input: vec![a_2d_name, b_input],
            output: vec![matmul_out_name.clone()],
            attribute: node.attribute.clone(),
            ..Default::default()
        });
        inserted.push(NodeProto {
            op_type: "Reshape".into(),
            name: format!("{node_tag}_restore"),
            input: vec![matmul_out_name, restore_shape_name],
            output: vec![out_name],
            ..Default::default()
        });
        new_nodes.push((idx, inserted));
        count += 1;
    }
    // Splice each replacement in place of its original MatMul. Every
    // earlier splice shifts later positions by (replacement length - 1).
    let mut offset = 0usize;
    for (idx, replacement) in new_nodes {
        let pos = idx + offset;
        let added = replacement.len();
        graph.node.splice(pos..=pos, replacement);
        offset += added - 1;
    }
    graph.initializer.extend(new_inits);
    graph.value_info.extend(new_vis);
    count
}

/// For each `Reshape` whose shape input is not an initializer, graph
/// input, or node output, synthesize an INT64 initializer from the
/// Reshape output's recorded shape. This happens only when that shape is
/// non-empty with all-positive dims. Returns the number of initializers added.
fn materialize_reshape_targets(graph: &mut GraphProto) -> usize {
    let mut init_names: HashSet = graph.initializer.iter().map(|i| i.name.clone()).collect();
    let input_names: HashSet = graph.input.iter().map(|i| i.name.clone()).collect();
    let produced_names: HashSet = graph
        .node
        .iter()
        .flat_map(|n| n.output.iter().cloned())
        .collect();
    let vi_shapes: HashMap> = graph
        .value_info
        .iter()
        .chain(graph.output.iter())
        .filter_map(|vi| shape_from_value_info(vi).map(|s| (vi.name.clone(), s)))
        .collect();
    let mut new_inits: Vec = Vec::new();
    let mut count = 0;
    for node in &graph.node {
        if node.op_type != "Reshape" {
            continue;
        }
        let shape_input = match node.input.get(1) {
            Some(n) if !n.is_empty() => n,
            _ => continue,
        };
        if init_names.contains(shape_input)
            || input_names.contains(shape_input)
            || produced_names.contains(shape_input)
        {
            continue;
        }
        let out_name = match node.output.first() {
            Some(n) if !n.is_empty() => n,
            _ => continue,
        };
        let out_shape = match vi_shapes.get(out_name) {
            Some(s) if !s.is_empty() && s.iter().all(|&d| d > 0) => s,
            _ => continue,
        };
        new_inits.push(TensorProto {
            name: shape_input.clone(),
            data_type: TensorProto::INT64,
            dims: vec![out_shape.len() as i64],
            int64_data: out_shape.clone(),
            ..Default::default()
        });
        // Track the new name so a second Reshape sharing it adds no duplicate.
        init_names.insert(shape_input.clone());
        count += 1;
    }
    graph.initializer.extend(new_inits);
    count
}

================================================
FILE: crates/dsperse/src/slicer/onnx_shapes.rs
================================================
use super::onnx_proto::{ModelProto, ValueInfoProto, onnx};

/// Return the shape recorded in a tensor-typed `vi`. Returns `None` if `vi`
/// is not a tensor, has no shape, or any dim is not a concrete `dim_value`.
pub fn shape_from_value_info(vi: &ValueInfoProto) -> Option> {
    let tp = vi.r#type.as_ref()?;
    let onnx::type_proto::Value::TensorType(tensor) = tp.value.as_ref()? else {
        return None;
    };
    let shape_proto = tensor.shape.as_ref()?;
    let mut dims = Vec::new();
    for d in &shape_proto.dim {
        match &d.value {
            Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => dims.push(*v),
            _ => return None,
        }
    }
    Some(dims)
}

/// Return the element type of a tensor-typed `vi`, or `None` otherwise.
pub fn elem_type_from_value_info(vi: &ValueInfoProto) -> Option {
    let tp = vi.r#type.as_ref()?;
    let onnx::type_proto::Value::TensorType(tensor) = tp.value.as_ref()? else {
        return None;
    };
    Some(tensor.elem_type)
}

/// Lossy shape accessor: non-`dim_value` dims map to 0, and a missing
/// tensor type or shape yields an empty vector.
pub fn vi_shape(vi: &ValueInfoProto) -> Vec {
    vi.r#type
        .as_ref()
        .and_then(|t| match &t.value {
            Some(onnx::type_proto::Value::TensorType(tt)) => tt.shape.as_ref(),
            _ => None,
        })
        .map(|s| {
            s.dim
                .iter()
                .map(|d| match &d.value {
                    Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => *v,
                    _ => 0,
                })
                .collect()
        })
        .unwrap_or_default()
}

/// Overwrite the shape of a tensor-typed `vi` with the concrete dims in
/// `shape`. This is a no-op when `vi` has no type or is not a tensor.
pub fn set_vi_shape(vi: &mut ValueInfoProto, shape: &[i64]) {
    if let Some(ref mut tp) = vi.r#type
        && let Some(onnx::type_proto::Value::TensorType(ref mut tt)) = tp.value
    {
        tt.shape = Some(onnx::TensorShapeProto {
            dim: shape
                .iter()
                .map(|&d| onnx::tensor_shape_proto::Dimension {
                    denotation: String::new(),
                    value: Some(onnx::tensor_shape_proto::dimension::Value::DimValue(d)),
                })
                .collect(),
        });
    }
}

/// Drop every `value_info` entry that has a symbolic (`dim_param`) dim.
/// On graph outputs, symbolic dims are reset to unset (`None`).
/// Returns the number of `value_info` entries removed.
pub fn strip_symbolic_value_info(model: &mut ModelProto) -> usize {
    let graph = match model.graph.as_mut() {
        Some(g) => g,
        None => return 0,
    };
    let has_symbolic = |vi: &ValueInfoProto| -> bool {
        vi.r#type
            .as_ref()
            .and_then(|t| match &t.value {
                Some(onnx::type_proto::Value::TensorType(tt)) => tt.shape.as_ref(),
                _ => None,
            })
            .is_some_and(|s| {
                s.dim.iter().any(|d| {
                    matches!(
                        &d.value,
                        Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_))
                    )
                })
            })
    };
    let before = graph.value_info.len();
    graph.value_info.retain(|vi| !has_symbolic(vi));
    let removed = before - graph.value_info.len();
    for out in &mut graph.output {
        if let Some(ref mut tp) = out.r#type
            && let Some(onnx::type_proto::Value::TensorType(ref mut tt)) = tp.value
            && let Some(ref mut shape) = tt.shape
        {
            for d in &mut shape.dim {
                if matches!(
                    &d.value,
                    Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_))
                ) {
                    d.value = None;
                }
            }
        }
    }
    if removed > 0 {
        tracing::info!(
            removed,
            "stripped value_info entries with symbolic dimensions"
        );
    }
    removed
}

/// Make graph-input shapes concrete.
///
/// With `explicit_shape`, each input that has any symbolic or unset dim is
/// filled from `explicit_shape`. The rank must match, and already-fixed dims
/// must agree. Without it, an input whose only dynamic dim is dim 0 gets that
/// dim set to 1.
///
/// Returns the number of inputs resolved.
///
/// # Errors
/// Returns `DsperseError::Slicer` in these cases:
/// - more than one input has non-batch dynamic dims while `explicit_shape`
///   is given;
/// - a rank mismatch or a conflicting fixed dim;
/// - a non-batch dim is dynamic and no explicit shape was given.
pub fn resolve_dynamic_input_shapes(
    model: &mut ModelProto,
    explicit_shape: Option<&[i64]>,
) -> crate::error::Result {
    let graph = match model.graph.as_mut() {
        Some(g) => g,
        None => return Ok(0),
    };
    let has_non_batch_symbolic = |inp: &&ValueInfoProto| -> bool {
        inp.r#type
            .as_ref()
            .and_then(|t| match &t.value {
                Some(onnx::type_proto::Value::TensorType(tt)) => tt.shape.as_ref(),
                _ => None,
            })
            .is_some_and(|s| {
                s.dim.iter().skip(1).any(|d| {
                    matches!(
                        &d.value,
                        Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None
                    )
                })
            })
    };
    let symbolic_count = graph.input.iter().filter(has_non_batch_symbolic).count();
    if symbolic_count > 1 && explicit_shape.is_some() {
        return Err(crate::error::DsperseError::Slicer(format!(
            "model has {symbolic_count} inputs with non-batch dynamic dimensions; \
             --input-shape applies to a single input. Per-input shapes not yet supported."
        )));
    }
    let mut resolved = 0;
    for inp in &mut graph.input {
        let tp = match inp.r#type.as_mut() {
            Some(t) => t,
            None => continue,
        };
        let tensor = match &mut tp.value {
            Some(onnx::type_proto::Value::TensorType(tt)) => tt,
            _ => continue,
        };
        let shape = match tensor.shape.as_mut() {
            Some(s) => s,
            None => continue,
        };
        let has_symbolic = shape.dim.iter().any(|d| {
            matches!(
                &d.value,
                Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None
            )
        });
        if !has_symbolic {
            continue;
        }
        if let Some(explicit) = explicit_shape {
            // NOTE(review): the explicit shape is applied to every input
            // with any dynamic dim (batch-only ones included), not only the
            // single input counted by `symbolic_count`. Confirm this is
            // intended for multi-input models.
            if explicit.len() != shape.dim.len() {
                return Err(crate::error::DsperseError::Slicer(format!(
                    "input '{}' has rank {} but --input-shape provides {} dims",
                    inp.name,
                    shape.dim.len(),
                    explicit.len()
                )));
            }
            for (d, &v) in shape.dim.iter_mut().zip(explicit.iter()) {
                if let Some(onnx::tensor_shape_proto::dimension::Value::DimValue(existing)) =
                    &d.value
                {
                    if *existing != v {
                        return Err(crate::error::DsperseError::Slicer(format!(
                            "input '{}': --input-shape dim {} conflicts with fixed dim {}",
                            inp.name, v, existing
                        )));
                    }
                } else {
                    d.value = Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v));
                }
            }
            tracing::info!(input = %inp.name, shape = ?explicit, "applied explicit input shape");
            resolved += 1;
            continue;
        }
        let non_batch_symbolic = shape.dim.iter().skip(1).any(|d| {
            matches!(
                &d.value,
                Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None
            )
        });
        if non_batch_symbolic {
            let dim_names: Vec = shape
                .dim
                .iter()
                .map(|d| match &d.value {
                    Some(onnx::tensor_shape_proto::dimension::Value::DimParam(s)) => s.clone(),
                    Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => v.to_string(),
                    None => "?".into(),
                })
                .collect();
            return Err(crate::error::DsperseError::Slicer(format!(
                "model input '{}' has dynamic dimensions [{}]; provide --input-shape to set concrete values",
                inp.name,
                dim_names.join(", ")
            )));
        }
        shape.dim[0].value = Some(onnx::tensor_shape_proto::dimension::Value::DimValue(1));
        tracing::info!(input = %inp.name, "defaulted batch dimension to 1");
        resolved += 1;
    }
    Ok(resolved)
}

================================================
FILE: crates/dsperse/src/slicer/onnx_slicer.rs
================================================
use std::collections::{HashMap, HashSet};
use std::path::Path;

use super::analyzer::{self, AnalysisResult, NodeAnalysis};
use super::autotiler;
use super::materializer;
use super::onnx_proto;
use crate::error::{DsperseError, Result};
use crate::schema::metadata::{
    Dependencies, ModelMetadata, SliceMetadata, SliceShapeWrapper, TensorShape,
};
use crate::schema::tiling::DimSplitInfo;

/// Slice the ONNX model at `onnx_path` into segments.
///
/// The model is first normalized (opset, dynamic input shapes, symbolic
/// value_info, constant folding) and traced via tract for tensor shapes
/// and types. Shape-derived constants are then propagated, inline
/// LayerNorms fused, and `Div(X, X)` rewritten. Next, slice points are
/// chosen and each segment is checked for tiling / channel-split or
/// dim-split needs.
///
/// `model.onnx`, any dim-split templates, and the metadata file are
/// written under `output_path` (default: `slices/` next to the model).
/// Returns the saved metadata.
pub fn slice_model(
    onnx_path: &Path,
    output_path: Option<&Path>,
    tile_size: Option,
    jstprove_ops: &[&str],
    input_shape: Option<&[i64]>,
) -> Result {
    let mut model = onnx_proto::load_model(onnx_path)?;
    onnx_proto::normalize_opset(&mut model);
    onnx_proto::resolve_dynamic_input_shapes(&mut model, input_shape)?;
    onnx_proto::strip_symbolic_value_info(&mut model);
    let folded_constants = super::onnx_fold::fold_constant_nodes(&mut model);
    let tmp_dir = tempfile::tempdir().map_err(|e| DsperseError::io(e, onnx_path))?;
    let tract_path = tmp_dir.path().join("tract_model.onnx");
    onnx_proto::save_model(&model, &tract_path)?;
    tracing::info!("folding constants and tracing shapes via tract");
    let trace_result = super::trace::fold_and_trace_via_tract(&tract_path, &model)?;
    let mut traced_shapes = trace_result.shapes;
    let traced_types = trace_result.types;
    if let Some(graph) = model.graph.as_mut() {
        // Chains of shape-dependent ops (Shape -> Gather -> Reshape,
        // or nested ConstantOfShape pyramids) expose constants only
        // after earlier rounds have folded their producers, so run
        // propagate_constants_with_shapes to a fixpoint. A small
        // safety cap prevents an unexpected non-monotonic evaluator
        // from spinning indefinitely; propagation is monotone by
        // construction so the loop is expected to converge in O(1)
        // iterations even for the deepest chains we have observed.
        const SHAPE_CONST_PROP_ITERATION_CAP: usize = 16;
        let mut total_folded = 0usize;
        for pass in 0..SHAPE_CONST_PROP_ITERATION_CAP {
            let folded =
                super::onnx_fold::propagate_constants_with_shapes(graph, &traced_shapes);
            if folded == 0 {
                break;
            }
            total_folded += folded;
            tracing::info!(pass, folded, "shape-constant propagation pass");
        }
        if total_folded > 0 {
            tracing::info!(
                total_folded,
                "propagated shape-derived constants in parent graph"
            );
        }
    }
    let fused_ln = super::layernorm_fuse::fuse_inline_layernorms(&mut model, &mut traced_shapes);
    if fused_ln > 0 {
        tracing::info!(fused_ln, "fused inline LayerNorm patterns");
    }
    let self_div_rewrites =
        super::self_div_rewrite::rewrite_self_div_to_one(&mut model, &mut traced_shapes);
    if self_div_rewrites > 0 {
        tracing::info!(self_div_rewrites, "rewrote degenerate Div(X, X) nodes");
    }
    let missing: Vec = if let Some(graph) = &model.graph {
        let mut missing = Vec::new();
        for n in &graph.node {
            for out in &n.output {
                if !out.is_empty() && !traced_shapes.contains_key(out) {
                    missing.push(out.clone());
                }
            }
        }
        missing
    } else {
        Vec::new()
    };
    if !missing.is_empty() {
        tracing::warn!(
            count = missing.len(),
            first_few = ?&missing[..missing.len().min(5)],
            "unresolved tensor shapes after all inference passes"
        );
    }
    let analysis = analyzer::analyze(&model, Some(onnx_path))?;
    let output_dir = output_path.map(|p| p.to_path_buf()).unwrap_or_else(|| {
        onnx_path
            .parent()
            .unwrap_or_else(|| Path::new("."))
            .join("slices")
    });
    std::fs::create_dir_all(&output_dir).map_err(|e| DsperseError::io(e, &output_dir))?;
    let slice_points =
        determine_slice_points(&analysis, tile_size, jstprove_ops, &model, &traced_shapes);
    tracing::info!(points = ?slice_points, "determined slice points");
    debug_assert!(
        !slice_points.is_empty(),
        "complete_slice_points guarantees at least [0, end]"
    );
    let model_dest = output_dir.join("model.onnx");
    onnx_proto::save_model(&model, &model_dest)?;
    let segment_ranges = super::build_segment_ranges(&slice_points, None);
    let trimmed_points = &slice_points[..slice_points.len().saturating_sub(1)];
    let mut tiled_info = HashMap::new();
    let mut dim_split_info: HashMap)> = HashMap::new();
    for (seg_idx, _) in segment_ranges.iter().enumerate() {
        let slice_model = materializer::materialize_slice_model(
            &model,
            trimmed_points,
            &traced_shapes,
            &traced_types,
            seg_idx,
        )?;
        if let Some(detection) = autotiler::detect_tiling_needs(&slice_model, tile_size) {
            tiled_info.insert(seg_idx, detection);
            continue;
        }
        if let Some(graph) = slice_model.graph.as_ref() {
            let init_names: HashSet =
                graph.initializer.iter().map(|t| t.name.clone()).collect();
            let mut slice_shapes: HashMap> = HashMap::new();
            for vi in graph
                .input
                .iter()
                .chain(graph.output.iter())
                .chain(graph.value_info.iter())
            {
                let dims = onnx_proto::vi_shape(vi);
                if !dims.is_empty() {
                    slice_shapes.insert(vi.name.clone(), dims);
                }
            }
            for init in &graph.initializer {
                slice_shapes
                    .entry(init.name.clone())
                    .or_insert_with(|| init.dims.to_vec());
            }
            for (name, shape) in &traced_shapes {
                slice_shapes
                    .entry(name.clone())
                    .or_insert_with(|| shape.clone());
            }
            if let Some(detection) = autotiler::detect_dim_split(
                &graph.node,
                &slice_shapes,
                &init_names,
                autotiler::model_opset(&model),
            ) {
                // Build a tentative DimSplitInfo to attempt template
                // creation. A template path is recorded only when the
                // template materializes successfully. On failure the
                // detection is still recorded, with no template, so the
                // compiler skips the slice (see the Err arm below).
                let tentative_info = DimSplitInfo::from_detection(&detection, seg_idx, None);
                let slice_dir = output_dir.join(format!("slice_{seg_idx}")).join("payload");
                std::fs::create_dir_all(&slice_dir).map_err(|e| DsperseError::io(e, &slice_dir))?;
                match autotiler::create_dim_split_template(
                    &slice_model,
                    &tentative_info,
                    &slice_dir,
                    Some(&traced_shapes),
                ) {
                    Ok(tmpl_path) => {
                        let tmpl_rel = tmpl_path
                            .strip_prefix(&output_dir)
                            .map_err(|_| {
                                DsperseError::Slicer(format!(
                                    "dim-split template path {} is not under output dir {}",
                                    tmpl_path.display(),
                                    output_dir.display()
                                ))
                            })?
                            .to_string_lossy()
                            .into_owned();
                        tracing::info!(
                            slice = seg_idx,
                            estimated = detection.estimated_constraints,
                            num_groups = detection.num_groups,
                            split_kind = ?detection.split_kind,
                            "dim-split detected and template created"
                        );
                        dim_split_info.insert(seg_idx, (detection, Some(tmpl_rel)));
                    }
                    Err(e) => {
                        tracing::warn!(
                            slice = seg_idx,
                            estimated = detection.estimated_constraints,
                            error = %e,
                            "dim-split detected but template creation failed; \
                             slice will be skipped during compilation"
                        );
                        // Record detection with no template path so the
                        // compiler knows this slice was over-budget and
                        // should be skipped rather than falling through
                        // to monolithic compilation.
                        dim_split_info.insert(seg_idx, (detection, None));
                    }
                }
            }
        }
    }
    let slices = build_slice_metadata(
        &analysis,
        &slice_points,
        &segment_ranges,
        &traced_shapes,
        &tiled_info,
        &dim_split_info,
    );
    let mut metadata = ModelMetadata {
        original_model: analysis.original_model.clone().unwrap_or_default(),
        model_type: analysis.model_type.clone(),
        input_shape: analysis.input_shape.clone(),
        output_shapes: analysis.output_shapes.clone(),
        output_names: analysis.output_names.clone(),
        slice_points: slice_points[..slice_points.len().saturating_sub(1)].to_vec(),
        slices,
        dsperse_version: None,
        dsperse_rev: None,
        jstprove_version: None,
        jstprove_rev: None,
        traced_shapes: Some(traced_shapes),
        traced_types: Some(traced_types),
        original_model_path: Some("model.onnx".to_string()),
        folded_constant_names: folded_constants.into_iter().collect(),
    };
    metadata.stamp_version();
    metadata.save(&output_dir.join(crate::utils::paths::METADATA_FILE))?;
    tracing::info!(
        slices = metadata.slices.len(),
        tiled = tiled_info.len(),
        "slicing complete"
    );
    Ok(metadata)
}

/// Build one `SliceMetadata` per segment range. Each entry gets the
/// spatial/fixed-segment tiling, channel-split, or dim-split info recorded
/// for that segment index.
fn build_slice_metadata(
    analysis: &AnalysisResult,
    _slice_points: &[usize],
    segment_ranges: &[(usize, usize)],
    traced_shapes: &HashMap>,
    tiled_info: &HashMap,
    dim_split_info: &HashMap)>,
) -> Vec {
    let mut slices = Vec::new();
    for (seg_idx, &(start, end)) in segment_ranges.iter().enumerate() {
        let dependencies = analyzer::get_segment_dependencies(analysis, start, end);
        let shape = build_shape_from_traced(analysis, start, end, &dependencies, traced_shapes);
        let filename = format!("slice_{seg_idx}.onnx");
        let relative_path = format!("slice_{seg_idx}/payload/{filename}");
        let mut tiling = None;
        let mut channel_split = None;
        if let Some(detection) = tiled_info.get(&seg_idx) {
            match detection {
                autotiler::TilingDetection::Spatial {
                    input_name,
                    output_name,
                    input_names,
                    ndim,
                    c_in,
                    c_out,
                    h,
                    w,
                    tile_size: actual_tile,
                    halo,
                    tiles_y,
                    tiles_x,
                    out_tile,
                    stride,
                } => {
                    tiling = Some(crate::schema::tiling::TilingInfo {
                        slice_idx: seg_idx,
                        tile_size: *actual_tile as usize,
                        num_tiles: (*tiles_y * *tiles_x) as usize,
                        tiles_y: *tiles_y as usize,
                        tiles_x: *tiles_x as usize,
                        halo: *halo,
                        out_tile: *out_tile,
                        stride: *stride,
                        c_in: *c_in as usize,
                        c_out: *c_out as usize,
                        input_name: input_name.clone(),
                        output_name: output_name.clone(),
                        input_names: input_names.clone(),
                        ndim: *ndim as usize,
                        h: *h as usize,
                        w: *w as usize,
                        tile: Some(crate::schema::tiling::TileInfo {
                            path: format!("slice_{seg_idx}/payload/tiles/tile.onnx"),
                            conv_out: *out_tile,
                            jstprove_circuit_path: None,
                        }),
                        tiles: None,
                        segment_size: None,
                        total_elements: None,
                        original_shape: vec![],
                    });
                }
                autotiler::TilingDetection::FixedSegment {
                    input_name,
                    output_name,
                    input_names,
                    total_elements,
                    segment_size,
                    num_segments,
                    original_shape,
                } => {
                    tiling = Some(crate::schema::tiling::TilingInfo {
                        slice_idx: seg_idx,
                        tile_size: *segment_size as usize,
                        num_tiles: *num_segments as usize,
                        tiles_y: *num_segments as usize,
                        tiles_x: 1,
                        halo: [0, 0, 0, 0],
                        out_tile: [*segment_size, 1],
                        stride: [1, 1],
                        c_in: 1,
                        c_out: 1,
                        input_name: input_name.clone(),
                        output_name: output_name.clone(),
                        input_names: input_names.clone(),
                        ndim: 1,
                        h: *total_elements as usize,
                        w: 1,
                        tile: Some(crate::schema::tiling::TileInfo {
                            path: format!("slice_{seg_idx}/payload/tiles/tile.onnx"),
                            conv_out: [*segment_size, 1],
                            jstprove_circuit_path: None,
                        }),
                        tiles: None,
                        segment_size: Some(*segment_size as usize),
                        total_elements: Some(*total_elements as usize),
                        original_shape: original_shape.clone(),
                    });
                }
                autotiler::TilingDetection::ChannelSplit {
                    input_name,
                    output_name,
                    c_in,
                    c_out,
                    h,
                    w,
                    num_groups,
                    channels_per_group,
                } => {
                    channel_split = Some(crate::schema::tiling::ChannelSplitInfo {
                        slice_idx: seg_idx,
                        c_in: *c_in as usize,
                        c_out: *c_out as usize,
                        num_groups: *num_groups as usize,
                        channels_per_group: *channels_per_group as usize,
                        input_name: input_name.clone(),
                        output_name: output_name.clone(),
                        h: *h as usize,
                        w: *w as usize,
                        out_h: 0,
                        out_w: 0,
                        groups: Vec::new(),
                        bias_path: None,
                    });
                }
            }
        }
        let dim_split = dim_split_info
            .get(&seg_idx)
            .map(|(d, tmpl_rel)| DimSplitInfo::from_detection(d, seg_idx, tmpl_rel.clone()));
        slices.push(SliceMetadata {
            index: seg_idx,
            filename: filename.clone(),
            path: format!("payload/{filename}"),
            relative_path,
            shape: SliceShapeWrapper {
                tensor_shape: shape,
            },
            dependencies,
            tiling,
            channel_split,
            dim_split,
            compilation: Default::default(),
            slice_metadata: None,
            slice_metadata_relative_path: None,
        });
    }
    slices
}

/// Collect traced shapes for a segment's filtered inputs and outputs;
/// tensors without a traced shape are omitted.
fn build_shape_from_traced(
    _analysis: &AnalysisResult,
    _start: usize,
    _end: usize,
    dependencies: &Dependencies,
    traced_shapes: &HashMap>,
) -> TensorShape {
    let input_shapes: Vec> = dependencies
        .filtered_inputs
        .iter()
        .filter_map(|name| traced_shapes.get(name).cloned())
        .collect();
    let output_shapes: Vec> = dependencies
        .output
        .iter()
        .filter_map(|name| traced_shapes.get(name).cloned())
        .collect();
    TensorShape {
        input: input_shapes,
        output: output_shapes,
    }
}

/// Choose slice points.
///
/// Start with every node whose `parameter_details` is non-empty, then run
/// the conv-isolation, expensive-op-isolation, jstprove-support,
/// (optional) tiling, constant-only and control-flow passes in that order.
/// Finally sort, dedup, and complete the points.
fn determine_slice_points(
    analysis: &AnalysisResult,
    tile_size: Option,
    jstprove_ops: &[&str],
    model: &onnx_proto::ModelProto,
    traced_shapes: &HashMap>,
) -> Vec {
    let mut points: HashSet = HashSet::new();
    for node in analysis.nodes.values() {
        if !node.parameter_details.is_empty() {
            points.insert(node.index);
        }
    }
    let mut sorted_points: Vec = points.into_iter().collect();
    sorted_points.sort();
    sorted_points = isolate_conv(&sorted_points, analysis);
    sorted_points = isolate_expensive_ops(&sorted_points, analysis, model, traced_shapes);
    sorted_points = optimize_jstprove_slices(&sorted_points, analysis, jstprove_ops);
    if tile_size.is_some() {
        sorted_points = optimize_for_tiling(&sorted_points, analysis);
    }
    sorted_points = filter_constant_only_slices(&sorted_points, analysis);
    sorted_points = merge_control_flow_segments(&sorted_points, analysis);
    sorted_points.sort();
    sorted_points.dedup();
    complete_slice_points(&mut sorted_points, analysis);
    sorted_points
}

/// Shared driver for slice-point passes. `mutate` receives a mutable set
/// of the current points, the analysis nodes sorted by index, and the
/// largest node index. The resulting points are returned sorted, with any
/// point above that index dropped.
fn optimize_points(
    points: &[usize],
    analysis: &AnalysisResult,
    mutate: impl FnOnce(&mut HashSet, &[&NodeAnalysis], usize),
) -> Vec {
    let mut updated: HashSet = points.iter().copied().collect();
    let mut sorted_nodes: Vec<&NodeAnalysis> = analysis.nodes.values().collect();
    sorted_nodes.sort_by_key(|n| n.index);
    let max_idx = sorted_nodes.last().map(|n| n.index).unwrap_or(0);
    mutate(&mut updated, &sorted_nodes, max_idx);
    let mut v: Vec = updated.into_iter().filter(|&p| p <= max_idx).collect();
    v.sort();
    v
}

/// Ops that `isolate_conv` gives their own slice boundary.
fn is_spatial_primary(op: &str) -> bool {
    op == "Conv" || op == "MaxPool"
}

/// Insert slice points before AND after every ONNX node whose
/// estimated constraint count exceeds
/// [`autotiler::MAX_ESTIMATED_CONSTRAINTS`]. Each "expensive" op
/// (large MatMul, LayerNormalization, Softmax, etc.) becomes a
/// single-node slice so the dim-split detector sees an unambiguous
/// shape and the runner doesn't need to trace which axis lives where
/// through Transpose / Reshape neighbours. Small ops keep their
/// existing grouping for circuit catalog reuse.
///
/// Add / Sub / Mul / Div / Pow are never isolated. A node with any
/// unresolved (or non-positive) boundary tensor shape is isolated
/// unconditionally.
fn isolate_expensive_ops(
    points: &[usize],
    analysis: &AnalysisResult,
    model: &onnx_proto::ModelProto,
    traced_shapes: &HashMap>,
) -> Vec {
    use jstprove_circuits::api::{EstimationConfig, estimate_op_constraints};
    let cfg = EstimationConfig::bn254_defaults();
    let threshold = autotiler::MAX_ESTIMATED_CONSTRAINTS;
    // Build a parallel index: ONNX-node-index -> &NodeProto so we can
    // resolve input/output tensor names per slicer-node.
    let onnx_nodes: Vec<&onnx_proto::NodeProto> = model
        .graph
        .as_ref()
        .map(|g| g.node.iter().collect())
        .unwrap_or_default();
    // Resolve a tensor's traced shape strictly: every dim must be a
    // concrete positive value. Coercing dynamic / -1 / 0 dims to 1
    // would silently drive the cost estimate to ~zero and let the
    // very nodes this pass exists to isolate sneak through. Returning
    // `None` for an unresolved tensor is the signal to pessimistically
    // isolate the node anyway.
    let to_usize_shape = |name: &String| -> Option> {
        let shape = traced_shapes.get(name)?;
        let mut out = Vec::with_capacity(shape.len());
        for &d in shape {
            if d <= 0 {
                return None;
            }
            out.push(d as usize);
        }
        Some(out)
    };
    // Pure elementwise binary ops (Add / Sub / Mul / Div / Pow) are
    // never isolated. This is a coupling to jstprove_circuits's
    // single-op-slice invariants: when an isolated slice contains
    // exactly one Div with a runtime divisor, one Mul / Sub between
    // operands of broadcast-incompatible shapes, or one Pow whose
    // exponent is a non-constant tensor, the per-op layer builder
    // rejects the slice with a strict-mode error. When the same
    // pattern appears inside a larger multi-op slice the
    // dim-split / LayerNorm fusion machinery rewrites the
    // surrounding subgraph and the strict check passes. These ops
    // are also cheap to compile in absolute terms, so isolating them
    // buys little proving wall-clock and surfaces the strict-mode
    // failure more often.
    //
    // TODO: revisit when jstprove_circuits relaxes the single-op
    // invariants (or exposes a "permissive" mode) so we can drop
    // this exemption and let the autotiler decide based on cost.
    let elementwise_skip: HashSet<&str> =
        ["Add", "Sub", "Mul", "Div", "Pow"].into_iter().collect();
    optimize_points(points, analysis, |updated, sorted_nodes, max_idx| {
        for node in sorted_nodes {
            if elementwise_skip.contains(node.node_type.as_str()) {
                continue;
            }
            let Some(onnx_node) = onnx_nodes.get(node.index) else {
                continue;
            };
            // ONNX node inputs / outputs use "" to denote an
            // unbound optional slot (e.g. Conv with no bias, GRU
            // with no initial_h). Treating those as unresolved
            // boundary tensors makes every node carrying an empty
            // slot pessimistically isolate, even when the real
            // boundary tensors are fully shape-resolved. Skip the
            // empty entries so estimate_op_constraints sees only
            // the real boundary tensors.
            let in_shapes: Option>> = onnx_node
                .input
                .iter()
                .filter(|name| !name.is_empty())
                .map(&to_usize_shape)
                .collect();
            let out_shapes: Option>> = onnx_node
                .output
                .iter()
                .filter(|name| !name.is_empty())
                .map(&to_usize_shape)
                .collect();
            // If any boundary tensor is unresolved we cannot give an
            // honest cost estimate; isolate pessimistically so the
            // downstream compile path sees a single-op slice and can
            // either compile it successfully or skip it cleanly,
            // rather than silently grouping an unbounded op.
            let isolate = match (in_shapes, out_shapes) {
                (Some(ins), Some(outs)) => {
                    estimate_op_constraints(&node.node_type, &ins, &outs, &cfg) > threshold
                }
                _ => true,
            };
            if isolate {
                updated.insert(node.index);
                if node.index < max_idx {
                    updated.insert(node.index + 1);
                }
            }
        }
    })
}

/// Put a slice point at every Conv / MaxPool. Add another point just past
/// the run of slice-passthrough nodes that consume its outputs; inputs
/// that are initializers don't extend that chain.
fn isolate_conv(points: &[usize], analysis: &AnalysisResult) -> Vec {
    optimize_points(points, analysis, |updated, sorted_nodes, max_idx| {
        for (pos, node) in sorted_nodes.iter().enumerate() {
            if is_spatial_primary(&node.node_type) {
                updated.insert(node.index);
                let mut produced: HashSet<&str> = node
                    .dependencies
                    .output
                    .iter()
                    .map(|s| s.as_str())
                    .collect();
                let mut end = pos + 1;
                while end < sorted_nodes.len() {
                    let candidate = sorted_nodes[end];
                    if !super::is_slice_passthrough(&candidate.node_type) {
                        break;
                    }
                    let consumes_produced = candidate.dependencies.input.iter().any(|inp| {
                        !analysis.initializer_names.contains(inp) && produced.contains(inp.as_str())
                    });
                    if !consumes_produced {
                        break;
                    }
                    for out in &candidate.dependencies.output {
                        produced.insert(out.as_str());
                    }
                    end += 1;
                }
                if end < sorted_nodes.len() && sorted_nodes[end].index <= max_idx {
                    updated.insert(sorted_nodes[end].index);
                }
            }
        }
    })
}

/// Insert a slice point wherever membership in `jstprove_ops` changes
/// between consecutive nodes in index order.
fn optimize_jstprove_slices(
    points: &[usize],
    analysis: &AnalysisResult,
    jstprove_ops: &[&str],
) -> Vec {
    optimize_points(points, analysis, |updated, sorted_nodes, _max_idx| {
        let is_supported = |n: &NodeAnalysis| jstprove_ops.contains(&n.node_type.as_str());
        for i in 0..sorted_nodes.len().saturating_sub(1) {
            if is_supported(sorted_nodes[i]) != is_supported(sorted_nodes[i + 1]) {
                updated.insert(sorted_nodes[i + 1].index);
            }
        }
    })
}

/// Insert a slice point at each transition between tileable (Conv,
/// MaxPool, elementwise) and non-tileable nodes. A `Relu` that directly
/// follows a non-tileable node never triggers one.
fn optimize_for_tiling(points: &[usize], analysis: &AnalysisResult) -> Vec {
    optimize_points(points, analysis, |updated, sorted_nodes, _max_idx| {
        let is_tileable = |n: &NodeAnalysis| {
            n.node_type == "Conv" || n.node_type == "MaxPool" || super::is_elementwise(&n.node_type)
        };
        for i in 0..sorted_nodes.len().saturating_sub(1) {
            let curr = sorted_nodes[i];
            let next = sorted_nodes[i + 1];
            if !is_tileable(curr) && next.node_type == "Relu" {
                continue;
            }
            if is_tileable(curr) != is_tileable(next) {
                updated.insert(next.index);
            }
        }
    })
}

/// Drop a slice point when every index in `[previous point (or 0), point)`
/// is a `Constant` node. Indices without an analysis node count as
/// constant. Removing the point merges that range into the following slice.
fn filter_constant_only_slices(points: &[usize], analysis: &AnalysisResult) -> Vec {
    if points.is_empty() {
        return points.to_vec();
    }
    let nodes_by_idx: HashMap = analysis.nodes.values().map(|n| (n.index, n)).collect();
    let mut to_remove: HashSet = HashSet::new();
    for (i, &end_idx) in points.iter().enumerate() {
        let start_idx = if i > 0 { points[i - 1] } else { 0 };
        if start_idx == end_idx {
            continue;
        }
        let all_constant = (start_idx..end_idx).all(|idx| {
            nodes_by_idx
                .get(&idx)
                .map(|n| n.node_type == "Constant")
                .unwrap_or(true)
        });
        if all_constant {
            to_remove.insert(end_idx);
        }
    }
    if !to_remove.is_empty() {
        tracing::info!(count = to_remove.len(), "merged constant-only slices");
    }
    points
        .iter()
        .filter(|p| !to_remove.contains(p))
        .copied()
        .collect()
}

/// Remove every slice point in `(producer, node]` for each control-flow
/// node and each of its input producers. This keeps a control-flow node in
/// the same slice as the nodes feeding it.
fn merge_control_flow_segments(points: &[usize], analysis: &AnalysisResult) -> Vec {
    let output_to_node_idx: HashMap<&str, usize> = analysis
        .nodes
        .values()
        .flat_map(|n| {
            n.dependencies
                .output
                .iter()
                .map(move |o| (o.as_str(), n.index))
        })
        .collect();
    let mut to_remove: HashSet = HashSet::new();
    for node in analysis.nodes.values() {
        if !super::is_control_flow(&node.node_type) {
            continue;
        }
        for inp in &node.dependencies.input {
            if let Some(&producer_idx) = output_to_node_idx.get(inp.as_str()) {
                for &pt in points {
                    if pt > producer_idx && pt <= node.index {
                        to_remove.insert(pt);
                    }
                }
            }
        }
    }
    if !to_remove.is_empty() {
        tracing::info!(
            count = to_remove.len(),
            "removed slice points to preserve control flow node dependencies"
        );
    }
    points
        .iter()
        .filter(|p| !to_remove.contains(p))
        .copied()
        .collect()
}

fn complete_slice_points(points: &mut Vec, analysis: &AnalysisResult) {
    let max_index = analysis.nodes.values().map(|n| n.index).max().unwrap_or(0);
    let end = max_index + 1;
    if !points.contains(&0) {
        points.push(0);
    }
    if !points.contains(&end) {
points.push(end); } points.sort(); points.dedup(); } pub(crate) fn broadcast_shapes(shapes: &[&Vec]) -> Option> { if shapes.is_empty() { return None; } let max_rank = shapes.iter().map(|s| s.len()).max().unwrap_or(0); let mut result = vec![1i64; max_rank]; for shape in shapes { let offset = max_rank - shape.len(); for (i, &dim) in shape.iter().enumerate() { let ri = offset + i; if result[ri] == 1 { result[ri] = dim; } else if dim != 1 && dim != result[ri] { return None; } } } Some(result) } #[cfg(test)] mod tests { use super::*; use analyzer::NodeDependencies; fn make_analysis_with_params(nodes: Vec<(&str, usize, &str, bool)>) -> AnalysisResult { let mut node_map = HashMap::new(); for (name, index, op_type, has_params) in &nodes { let mut parameter_details = HashMap::new(); if *has_params { parameter_details.insert( format!("{}_weight", name), analyzer::ParameterDetail { shape: vec![3, 3], size: 9, }, ); } node_map.insert( name.to_string(), NodeAnalysis { index: *index, slice_name: format!("{}_{}", op_type, index), node_type: op_type.to_string(), parameter_details, dependencies: NodeDependencies { input: vec![], output: vec![], }, }, ); } AnalysisResult { original_model: None, model_type: "ONNX".to_string(), node_count: nodes.len(), initializer_count: 0, input_shape: vec![], output_shapes: vec![], output_names: vec![], opset_version: Some(18), nodes: node_map, initializer_names: HashSet::new(), } } const TEST_OPS: &[&str] = &["Conv", "Gemm", "MatMul"]; #[test] fn complete_slice_points_adds_boundaries() { let analysis = make_analysis_with_params(vec![ ("a", 0, "Conv", false), ("b", 1, "Relu", false), ("c", 2, "Conv", false), ]); let mut points = vec![1]; complete_slice_points(&mut points, &analysis); assert!(points.contains(&0)); assert!(points.contains(&3)); assert!(points.contains(&1)); } #[test] fn complete_slice_points_already_complete() { let analysis = make_analysis_with_params(vec![("a", 0, "Conv", false), ("b", 1, "Relu", false)]); let mut points = vec![0, 
2]; complete_slice_points(&mut points, &analysis); assert_eq!(points, vec![0, 2]); } #[test] fn complete_slice_points_deduplicates() { let analysis = make_analysis_with_params(vec![("a", 0, "Conv", false)]); let mut points = vec![0, 0, 1, 1]; complete_slice_points(&mut points, &analysis); assert_eq!(points, vec![0, 1]); } #[test] fn isolate_conv_inserts_boundaries() { let analysis = make_analysis_with_params(vec![ ("a", 0, "Conv", false), ("b", 1, "Relu", false), ("c", 2, "MaxPool", false), ("d", 3, "Conv", false), ("e", 4, "Relu", false), ]); let points = vec![0, 3]; let result = isolate_conv(&points, &analysis); assert!(result.contains(&0)); assert!(result.contains(&1)); assert!(result.contains(&3)); assert!(result.contains(&4)); } #[test] fn isolate_conv_no_convs() { let analysis = make_analysis_with_params(vec![("a", 0, "Relu", false), ("b", 1, "Reshape", false)]); let points = vec![0]; let result = isolate_conv(&points, &analysis); assert_eq!(result, vec![0]); } #[test] fn isolate_maxpool_gets_boundary() { let analysis = make_analysis_with_params(vec![("a", 0, "Relu", false), ("b", 1, "MaxPool", false)]); let points = vec![0]; let result = isolate_conv(&points, &analysis); assert_eq!(result, vec![0, 1]); } #[test] fn optimize_jstprove_slices_splits_at_boundary() { let analysis = make_analysis_with_params(vec![ ("a", 0, "Conv", false), ("b", 1, "Relu", false), ("c", 2, "Conv", false), ]); let points = vec![0]; let result = optimize_jstprove_slices(&points, &analysis, TEST_OPS); assert!(result.contains(&1)); assert!(result.contains(&2)); } #[test] fn optimize_jstprove_slices_all_supported() { let analysis = make_analysis_with_params(vec![("a", 0, "Conv", false), ("b", 1, "Conv", false)]); let points = vec![0, 1]; let result = optimize_jstprove_slices(&points, &analysis, TEST_OPS); assert_eq!(result, vec![0, 1]); } #[test] fn optimize_for_tiling_maxpool_stays_grouped() { let analysis = make_analysis_with_params(vec![ ("a", 0, "Conv", false), ("b", 1, "Relu", 
false), ("c", 2, "MaxPool", false), ("d", 3, "Conv", false), ]); let points = vec![0, 3]; let result = optimize_for_tiling(&points, &analysis); assert!(!result.contains(&2)); } #[test] fn optimize_for_tiling_splits_at_non_tileable() { let analysis = make_analysis_with_params(vec![ ("a", 0, "Conv", false), ("b", 1, "Relu", false), ("c", 2, "Reshape", false), ("d", 3, "Conv", false), ]); let points = vec![0, 3]; let result = optimize_for_tiling(&points, &analysis); assert!(result.contains(&2)); } #[test] fn optimize_for_tiling_relu_after_non_tileable_kept() { let analysis = make_analysis_with_params(vec![ ("a", 0, "MaxPool", false), ("b", 1, "Relu", false), ("c", 2, "Conv", false), ]); let points = vec![0, 2]; let result = optimize_for_tiling(&points, &analysis); assert!(!result.contains(&1)); } #[test] fn filter_constant_only_slices_removes_constant_segments() { let analysis = make_analysis_with_params(vec![ ("a", 0, "Constant", false), ("b", 1, "Constant", false), ("c", 2, "Conv", false), ("d", 3, "Relu", false), ]); let points = vec![2, 4]; let result = filter_constant_only_slices(&points, &analysis); assert!(!result.contains(&2)); assert!(result.contains(&4)); } #[test] fn filter_constant_only_slices_keeps_non_constant() { let analysis = make_analysis_with_params(vec![("a", 0, "Conv", false), ("b", 1, "Relu", false)]); let points = vec![1, 2]; let result = filter_constant_only_slices(&points, &analysis); assert_eq!(result, vec![1, 2]); } #[test] fn filter_constant_only_slices_empty_points() { let analysis = make_analysis_with_params(vec![("a", 0, "Conv", false)]); let result = filter_constant_only_slices(&[], &analysis); assert!(result.is_empty()); } #[test] fn determine_slice_points_includes_parameterized_nodes() { let analysis = make_analysis_with_params(vec![ ("conv0", 0, "Conv", true), ("relu0", 1, "Relu", false), ("conv1", 2, "Conv", true), ("relu1", 3, "Relu", false), ]); let model = onnx_proto::ModelProto::default(); let traced = HashMap::new(); let points 
= determine_slice_points(&analysis, None, TEST_OPS, &model, &traced); assert!(points.contains(&0)); assert!(points.contains(&2)); let max = *points.last().unwrap(); assert_eq!(max, 4); } #[test] fn determine_slice_points_with_tile_size() { let analysis = make_analysis_with_params(vec![ ("conv0", 0, "Conv", true), ("relu0", 1, "Relu", false), ("pool", 2, "MaxPool", false), ("conv1", 3, "Conv", true), ]); let model = onnx_proto::ModelProto::default(); let traced = HashMap::new(); let points = determine_slice_points(&analysis, Some(1024), TEST_OPS, &model, &traced); assert!(points.contains(&0)); assert!(points.len() >= 3); } type NodeSpec<'a> = (&'a str, usize, &'a str, bool, Vec<&'a str>, Vec<&'a str>); fn make_analysis_with_deps(nodes: Vec>) -> AnalysisResult { let mut node_map = HashMap::new(); for (name, index, op_type, has_params, inputs, outputs) in &nodes { let mut parameter_details = HashMap::new(); if *has_params { parameter_details.insert( format!("{}_weight", name), analyzer::ParameterDetail { shape: vec![3, 3], size: 9, }, ); } node_map.insert( name.to_string(), NodeAnalysis { index: *index, slice_name: format!("{}_{}", op_type, index), node_type: op_type.to_string(), parameter_details, dependencies: NodeDependencies { input: inputs.iter().map(|s| s.to_string()).collect(), output: outputs.iter().map(|s| s.to_string()).collect(), }, }, ); } AnalysisResult { original_model: None, model_type: "ONNX".to_string(), node_count: nodes.len(), initializer_count: 0, input_shape: vec![], output_shapes: vec![], output_names: vec![], opset_version: Some(18), nodes: node_map, initializer_names: HashSet::new(), } } #[test] fn merge_control_flow_removes_boundary_between_producer_and_loop() { let analysis = make_analysis_with_deps(vec![ ("conv0", 0, "Conv", true, vec!["x"], vec!["conv_out"]), ( "relu0", 1, "Relu", false, vec!["conv_out"], vec!["relu_out"], ), ( "matmul0", 2, "MatMul", true, vec!["relu_out"], vec!["mm_out"], ), ( "loop0", 3, "Loop", false, vec!["trip", 
"cond", "init", "relu_out"], vec!["loop_out"], ), ]); let points = vec![0, 2, 4]; let result = merge_control_flow_segments(&points, &analysis); assert!( !result.contains(&2), "slice point 2 separates relu0 (producer of relu_out at idx 1) from Loop (idx 3); must be removed: {:?}", result ); } #[test] fn merge_control_flow_preserves_unrelated_boundaries() { let analysis = make_analysis_with_deps(vec![ ("conv0", 0, "Conv", true, vec!["x"], vec!["conv_out"]), ( "relu0", 1, "Relu", false, vec!["conv_out"], vec!["relu_out"], ), ( "conv1", 2, "Conv", true, vec!["relu_out"], vec!["conv1_out"], ), ( "relu1", 3, "Relu", false, vec!["conv1_out"], vec!["relu1_out"], ), ( "loop0", 4, "Loop", false, vec!["trip", "cond", "relu1_out"], vec!["loop_out"], ), ]); let points = vec![0, 2, 5]; let result = merge_control_flow_segments(&points, &analysis); assert!( result.contains(&2), "boundary at 2 is between conv0/relu0 and conv1/relu1, should be preserved since Loop only depends on relu1_out (idx 3): {:?}", result ); } #[test] fn merge_control_flow_no_control_flow_ops() { let analysis = make_analysis_with_deps(vec![ ("conv0", 0, "Conv", true, vec!["x"], vec!["conv_out"]), ("relu0", 1, "Relu", false, vec!["conv_out"], vec!["y"]), ]); let points = vec![0, 1, 2]; let result = merge_control_flow_segments(&points, &analysis); assert_eq!(result, vec![0, 1, 2]); } /// Regression for PR #183: isolate_conv's inner grouping walk /// must treat the LAYOUT_OPS set (Reshape / Transpose / /// Flatten / Squeeze / Unsqueeze / Gather) as passthroughs so /// that Conv -> Reshape -> MatMul places the trailing compile /// boundary on the heavy MatMul rather than on the Reshape /// that sits between them. Before the is_slice_passthrough /// split these ops were absent from is_shape_preserving and /// the walk terminated on the Reshape, isolating it into its /// own slice. 
#[test] fn isolate_conv_absorbs_reshape_then_boundaries_on_matmul() { let analysis = make_analysis_with_deps(vec![ ("conv0", 0, "Conv", true, vec!["x"], vec!["conv_out"]), ( "reshape0", 1, "Reshape", false, vec!["conv_out", "shape"], vec!["reshape_out"], ), ( "matmul0", 2, "MatMul", true, vec!["reshape_out", "matmul0_weight"], vec!["matmul_out"], ), ]); let points = vec![0, 3]; let result = isolate_conv(&points, &analysis); assert!( result.contains(&0), "isolate_conv should insert a boundary at the Conv itself: {result:?}" ); assert!( result.contains(&2), "is_slice_passthrough should absorb Reshape into the Conv slice and place the trailing boundary on MatMul at index 2: {result:?}" ); assert!( !result.contains(&1), "Reshape at index 1 must not become its own slice boundary when it sits between a Conv and a heavy op: {result:?}" ); } /// Transpose + Squeeze variant so we also cover the other /// layout ops added to LAYOUT_OPS. #[test] fn isolate_conv_absorbs_transpose_chain_then_boundaries_on_matmul() { let analysis = make_analysis_with_deps(vec![ ("conv0", 0, "Conv", true, vec!["x"], vec!["conv_out"]), ( "transpose0", 1, "Transpose", false, vec!["conv_out"], vec!["trans_out"], ), ( "squeeze0", 2, "Squeeze", false, vec!["trans_out"], vec!["sq_out"], ), ( "matmul0", 3, "MatMul", true, vec!["sq_out", "matmul0_weight"], vec!["matmul_out"], ), ]); let points = vec![0, 4]; let result = isolate_conv(&points, &analysis); assert!(result.contains(&0)); assert!( result.contains(&3), "Transpose + Squeeze chain should absorb into the Conv slice, leaving MatMul at index 3 as the boundary: {result:?}" ); assert!(!result.contains(&1)); assert!(!result.contains(&2)); } /// Counter-case: a layout op whose input is NOT produced by /// the preceding Conv slice must still break the walk, so the /// consumes_produced guard is exercised. 
#[test] fn isolate_conv_stops_when_passthrough_consumes_external_input() { let analysis = make_analysis_with_deps(vec![ ("conv0", 0, "Conv", true, vec!["x"], vec!["conv_out"]), ( "reshape0", 1, "Reshape", false, // Reshape consumes an external tensor, not // conv_out, so is_slice_passthrough being true is // not sufficient to absorb it. vec!["external_y", "shape"], vec!["reshape_out"], ), ]); let points = vec![0, 2]; let result = isolate_conv(&points, &analysis); assert!(result.contains(&0)); assert!( result.contains(&1), "Reshape that doesn't consume any conv-produced tensor should remain the trailing boundary: {result:?}" ); } } ================================================ FILE: crates/dsperse/src/slicer/self_div_rewrite.rs ================================================ use std::collections::HashMap; use super::onnx_proto::{ModelProto, TensorProto}; /// Graph rewrite placeholder: detecting `Div(X, X)` and collapsing it to a /// constant-ones tensor is only sound when the element type is a floating /// point dtype AND every element of X is finite AND non-zero. Without a /// traced-properties side channel carrying that guarantee the rewrite would /// silently turn `0 / 0 = NaN` and integer underflow into `1`. /// /// The earlier implementation rewrote unconditionally and is preserved here /// as documentation so that a follow-up can plug it in once /// `traced_dtypes` / `traced_all_finite_nonzero` maps are available. pub fn rewrite_self_div_to_one( _model: &mut ModelProto, _traced_shapes: &mut HashMap>, ) -> usize { // Intentionally a no-op: see module doc. Re-enable behind a proper // traced-properties guard once available. 
let _ = TensorProto::FLOAT; 0 } ================================================ FILE: crates/dsperse/src/slicer/trace.rs ================================================ use std::collections::{HashMap, HashSet}; use std::path::Path; use super::onnx_proto::ModelProto; use crate::error::{DsperseError, Result}; pub(crate) struct TraceResult { pub shapes: HashMap>, pub types: HashMap, } pub(crate) fn fold_and_trace_via_tract( onnx_path: &Path, model: &ModelProto, ) -> Result { use tract_onnx::prelude::*; use tract_onnx::tract_hir::infer::InferenceSimplePlan; let loop_bodies = collect_loop_bodies(model); let tract_path = tag_all_outputs(onnx_path, model)?; let tract_model = std::sync::Arc::new( tract_onnx::onnx() .model_for_path(&tract_path) .map_err(|e| DsperseError::Slicer(format!("tract load: {e}")))?, ); if let Err(e) = std::fs::remove_file(&tract_path) { tracing::debug!(path = %tract_path.display(), error = %e, "failed to remove tagged model"); } let plan = InferenceSimplePlan::new(tract_model.clone()) .map_err(|e| DsperseError::Slicer(format!("plan creation: {e}")))?; let mut state = tract_onnx::tract_core::plan::SimpleState::new(&plan) .map_err(|e| DsperseError::Slicer(format!("state creation: {e}")))?; let mut input_tvs: TVec = tvec![]; for outlet in tract_model .input_outlets() .map_err(|e| DsperseError::Slicer(format!("input outlets: {e}")))? { let fact = tract_model .outlet_fact(*outlet) .map_err(|e| DsperseError::Slicer(format!("input fact: {e}")))?; let tensor = if let Ok(tf) = fact.to_typed_fact() { let shape: Vec = tf .shape .iter() .map(|d| d.to_i64().unwrap_or(1).max(1) as usize) .collect(); Tensor::zero_dt(tf.datum_type, &shape) .map_err(|e| DsperseError::Slicer(format!("zero tensor: {e}")))? 
} else { Tensor::zero::(&[1]).expect("scalar f32 allocation") }; input_tvs.push(tensor.into_tvalue()); } let shapes_cell = std::cell::RefCell::new(HashMap::>>::new()); let dtypes_cell = std::cell::RefCell::new(HashMap::>::new()); let failed_nodes = std::cell::RefCell::new(HashSet::::new()); let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { state.run_plan_with_eval(input_tvs, |session, op_state, node, inputs| { let tainted = node .inputs .iter() .any(|inp| failed_nodes.borrow().contains(&inp.node)); let outputs = if tainted { failed_nodes.borrow_mut().insert(node.id); let fallback = inputs.first().cloned().unwrap_or_else(|| { Tensor::zero::(&[1]) .expect("scalar f32 allocation") .into_tvalue() }); let n = node.outputs.len().max(1); (0..n).map(|_| fallback.clone()).collect() } else { let coerced = crate::backend::onnx::coerce_tdim_inputs(&inputs); let eval_result = if let Some(st) = op_state { st.eval(session, node.op.as_op(), coerced) } else { node.op.eval(coerced) }; match eval_result { Ok(o) => o, Err(e) => { if let Some(synth) = synthesize_loop_outputs(&node.name, &inputs, &loop_bodies) { tracing::info!( node = %node.name, outputs = synth.len(), "synthesized Loop output tensors from body subgraph shapes" ); synth } else { tracing::warn!( node = %node.name, op = %node.op.name(), error = %e, "op eval failed, using input[0] shape as fallback" ); failed_nodes.borrow_mut().insert(node.id); let fallback = inputs.first().cloned().unwrap_or_else(|| { Tensor::zero::(&[1]) .expect("scalar f32 allocation") .into_tvalue() }); let n = node.outputs.len().max(1); (0..n).map(|_| fallback.clone()).collect() } } } }; let node_shapes: Vec> = outputs .iter() .map(|t| t.shape().iter().map(|&d| d as i64).collect()) .collect(); let node_dtypes: Vec = outputs .iter() .map(|t| datum_type_to_onnx(t.datum_type())) .collect(); shapes_cell.borrow_mut().insert(node.id, node_shapes); dtypes_cell.borrow_mut().insert(node.id, node_dtypes); Ok::<_, TractError>(outputs) }) 
})); match &result { Ok(Ok(_)) => tracing::info!("tract inference run succeeded"), Ok(Err(e)) => { tracing::warn!(error = %e, "tract inference run produced errors; partial shapes may be available") } Err(_) => { return Err(DsperseError::Slicer( "tract inference panicked; no shape data recovered".into(), )); } } let run_shapes = shapes_cell.into_inner(); let run_dtypes = dtypes_cell.into_inner(); let failed = failed_nodes.into_inner(); tracing::info!( traced_nodes = run_shapes.len(), "constant folding and shape capture complete" ); let mut shapes: HashMap> = HashMap::new(); let mut types: HashMap = HashMap::new(); for (node_id, node_shapes) in &run_shapes { if failed.contains(node_id) { continue; } let node_dtypes = run_dtypes.get(node_id); for (slot, shape) in node_shapes.iter().enumerate() { let dt = node_dtypes.and_then(|d| d.get(slot)).copied().unwrap_or(1) as i32; // 1 = FLOAT let outlet = OutletId::new(*node_id, slot); if let Some(label) = tract_model.outlet_label(outlet) && !label.is_empty() { shapes.insert(label.to_string(), shape.clone()); types.insert(label.to_string(), dt); } let node = tract_model.node(*node_id); if !node.name.is_empty() { if slot == 0 { shapes .entry(node.name.clone()) .or_insert_with(|| shape.clone()); types.entry(node.name.clone()).or_insert(dt); } let qualified = format!("{}:{}", node.name, slot); shapes .entry(qualified.clone()) .or_insert_with(|| shape.clone()); types.entry(qualified).or_insert(dt); } } } if let Some(graph) = &model.graph { let mut extra: Vec<(String, Vec, Option)> = Vec::new(); for n in &graph.node { for (slot, out) in n.output.iter().enumerate() { if out.is_empty() || shapes.contains_key(out) { continue; } let key = if slot == 0 { n.name.clone() } else { format!("{}:{}", n.name, slot) }; if let Some(shape) = shapes.get(&key) { let dt = types.get(&key).copied(); extra.push((out.clone(), shape.clone(), dt)); } } } for (name, shape, dt) in extra { shapes.insert(name.clone(), shape); if let Some(dt) = dt { 
types.insert(name, dt); } } for init in &graph.initializer { if !init.dims.is_empty() { shapes .entry(init.name.clone()) .or_insert_with(|| init.dims.clone()); } if init.data_type != 0 { types.entry(init.name.clone()).or_insert(init.data_type); } } for inp in &graph.input { if let Some(shape) = super::onnx_shapes::shape_from_value_info(inp) { shapes.entry(inp.name.clone()).or_insert(shape); } if let Some(dt) = super::onnx_shapes::elem_type_from_value_info(inp) { types.entry(inp.name.clone()).or_insert(dt); } } resolve_absorbed_nodes(graph, &mut shapes); } tracing::info!(tensors = shapes.len(), "shape trace complete"); Ok(TraceResult { shapes, types }) } /// Save a copy of the ONNX model with every node output declared as a graph /// output. This forces tract to preserve outlet labels for all intermediate /// tensors, preventing them from being lost during op fusion. fn tag_all_outputs(onnx_path: &Path, model: &ModelProto) -> Result { let mut tagged = model.clone(); if let Some(ref mut graph) = tagged.graph { let existing: HashSet = graph.output.iter().map(|o| o.name.clone()).collect(); for node in &graph.node { for out in &node.output { if !out.is_empty() && !existing.contains(out) { graph.output.push(super::onnx_proto::ValueInfoProto { name: out.clone(), ..Default::default() }); } } } } let dir = onnx_path.parent().unwrap_or_else(|| Path::new(".")); let tagged_path = dir.join(format!("_tract_tagged_{}.onnx", std::process::id())); super::onnx_proto::save_model(&tagged, &tagged_path)?; Ok(tagged_path) } fn onnx_elem_type_to_datum(onnx_type: i32) -> Option { use tract_onnx::prelude::DatumType; match onnx_type { 1 => Some(DatumType::F32), 2 => Some(DatumType::U8), 3 => Some(DatumType::I8), 5 => Some(DatumType::I16), 6 => Some(DatumType::I32), 7 => Some(DatumType::I64), 9 => Some(DatumType::Bool), 10 => Some(DatumType::F16), 11 => Some(DatumType::F64), 12 => Some(DatumType::U32), 13 => Some(DatumType::U64), _ => None, } } fn datum_type_to_onnx(dt: 
tract_onnx::prelude::DatumType) -> u8 { use tract_onnx::prelude::DatumType; match dt { DatumType::F32 => 1, DatumType::U8 => 2, DatumType::I8 => 3, DatumType::U16 => 4, DatumType::I16 => 5, DatumType::I32 => 6, DatumType::I64 => 7, DatumType::Bool => 9, DatumType::F16 => 10, DatumType::F64 => 11, DatumType::U32 => 12, DatumType::U64 => 13, _ => 1, } } struct LoopBody { num_loop_carried: usize, num_scan: usize, scan_body_output_shapes: Vec>>, scan_body_output_dtypes: Vec>, } /// Collect Loop node body metadata from the ONNX graph. For scan outputs /// whose shapes can be statically determined from the body subgraph, store /// the body-side shape (without the leading trip-count dimension). fn collect_loop_bodies(model: &ModelProto) -> HashMap { let graph = match model.graph.as_ref() { Some(g) => g, None => return HashMap::new(), }; let mut known: HashMap> = HashMap::new(); for init in &graph.initializer { if !init.dims.is_empty() { known.insert(init.name.clone(), init.dims.clone()); } } for vi in graph .input .iter() .chain(graph.value_info.iter()) .chain(graph.output.iter()) { if let Some(shape) = super::onnx_shapes::shape_from_value_info(vi) { known.insert(vi.name.clone(), shape); } } let mut result = HashMap::new(); for node in &graph.node { if node.op_type != "Loop" { continue; } let body = match node .attribute .iter() .find(|a| a.name == "body") .and_then(|a| a.g.as_ref()) { Some(b) => b, None => continue, }; let num_loop_carried = node.input.len().saturating_sub(2); let num_body_out = body.output.len().saturating_sub(1); let num_scan = num_body_out.saturating_sub(num_loop_carried); let mut scan_shapes = Vec::with_capacity(num_scan); let mut scan_dtypes = Vec::with_capacity(num_scan); for j in 0..num_scan { let body_out_idx = 1 + num_loop_carried + j; let body_vi = body.output.get(body_out_idx); let shape = body_vi.and_then(|vi| resolve_body_tensor_shape(&vi.name, body, graph, &known)); let dtype = 
body_vi.and_then(super::onnx_shapes::elem_type_from_value_info); scan_shapes.push(shape); scan_dtypes.push(dtype); } result.insert( node.name.clone(), LoopBody { num_loop_carried, num_scan, scan_body_output_shapes: scan_shapes, scan_body_output_dtypes: scan_dtypes, }, ); } result } /// During tract evaluation, when a Loop node fails, produce correctly-shaped /// zero tensors so downstream nodes receive valid inputs and are not tainted. /// /// Loop-carried output shapes come directly from the actual input tensors /// (inputs\[2..\]). Scan output shapes come from the pre-analyzed body /// subgraph with a leading dimension of 1 (single iteration assumption). fn synthesize_loop_outputs( node_name: &str, inputs: &[tract_onnx::prelude::TValue], loop_bodies: &HashMap, ) -> Option> { use tract_onnx::prelude::*; let body = loop_bodies.get(node_name)?; let mut tvs: TVec = tvec![]; for i in 0..body.num_loop_carried { let init_tensor = inputs.get(i + 2)?; let shape: Vec = init_tensor.shape().to_vec(); let tensor = Tensor::zero_dt(init_tensor.datum_type(), &shape).ok()?; tvs.push(tensor.into_tvalue()); } for j in 0..body.num_scan { let body_shape = body.scan_body_output_shapes.get(j)?; let shape: Vec = match body_shape { Some(bs) => { let mut s = vec![1usize]; s.extend(bs.iter().map(|&d| d.max(1) as usize)); s } None => { tracing::warn!( node = node_name, scan_idx = j, "scan output shape unknown, using [1,1] placeholder" ); vec![1, 1] } }; let dt = body .scan_body_output_dtypes .get(j) .and_then(|d| *d) .and_then(onnx_elem_type_to_datum) .unwrap_or(DatumType::F32); let tensor = Tensor::zero_dt(dt, &shape).ok()?; tvs.push(tensor.into_tvalue()); } Some(tvs) } /// Resolve shapes for ONNX graph nodes that tract absorbed or renamed, /// making them invisible in the tract shape output. Iterates until no /// more progress, using only rules already defined in the slicer module /// (shape-preserving ops, binary broadcast). 
fn resolve_absorbed_nodes( graph: &super::onnx_proto::GraphProto, shapes: &mut HashMap>, ) { let max_passes = 10; for _ in 0..max_passes { let mut progress = false; for node in &graph.node { for out in &node.output { if out.is_empty() || shapes.contains_key(out) { continue; } let op = node.op_type.as_str(); let shape = if super::is_shape_preserving(op) || op == "Identity" { node.input.first().and_then(|inp| shapes.get(inp).cloned()) } else if super::is_binary_arithmetic(op) { let resolved: Vec<&Vec> = node.input.iter().filter_map(|i| shapes.get(i)).collect(); let non_empty = node.input.iter().filter(|i| !i.is_empty()).count(); if resolved.len() == non_empty { super::onnx_slicer::broadcast_shapes(&resolved) } else { None } } else { None }; if let Some(s) = shape { shapes.insert(out.clone(), s); progress = true; } } } if !progress { break; } } } fn resolve_body_tensor_shape( name: &str, body: &super::onnx_proto::GraphProto, outer_graph: &super::onnx_proto::GraphProto, known_shapes: &HashMap>, ) -> Option> { resolve_body_tensor_shape_inner(name, body, outer_graph, known_shapes, 0) } fn resolve_body_tensor_shape_inner( name: &str, body: &super::onnx_proto::GraphProto, outer_graph: &super::onnx_proto::GraphProto, known_shapes: &HashMap>, depth: usize, ) -> Option> { if depth > 32 { return None; } for vi in body.output.iter().chain(body.value_info.iter()) { if vi.name == name && let Some(shape) = super::onnx_shapes::shape_from_value_info(vi) { return Some(shape); } } for init in body .initializer .iter() .chain(outer_graph.initializer.iter()) { if init.name == name && !init.dims.is_empty() { return Some(init.dims.to_vec()); } } if let Some(shape) = known_shapes.get(name) { return Some(shape.clone()); } let producer = body .node .iter() .find(|n| n.output.contains(&name.to_string()))?; let op = producer.op_type.as_str(); if super::is_shape_preserving(op) || op == "Identity" { let inp = producer.input.first()?; return resolve_body_tensor_shape_inner(inp, body, outer_graph, 
known_shapes, depth + 1); } if super::is_binary_arithmetic(op) { let resolved: Vec> = producer .input .iter() .filter_map(|inp| { resolve_body_tensor_shape_inner(inp, body, outer_graph, known_shapes, depth + 1) }) .collect(); let refs: Vec<&Vec> = resolved.iter().collect(); return super::onnx_slicer::broadcast_shapes(&refs); } if op == "Concat" { let axis = super::onnx_proto::get_attribute_int(producer, "axis")?; let input_shapes: Vec> = producer .input .iter() .filter_map(|inp| { resolve_body_tensor_shape_inner(inp, body, outer_graph, known_shapes, depth + 1) }) .collect(); if input_shapes.len() != producer.input.len() || input_shapes.is_empty() { return None; } let rank = input_shapes[0].len() as i64; if axis < -rank || axis >= rank { return None; } let axis_idx = if axis < 0 { (rank + axis) as usize } else { axis as usize }; let mut result = input_shapes[0].clone(); for shape in &input_shapes[1..] { if let Some(d) = result.get_mut(axis_idx) { *d += shape.get(axis_idx).copied().unwrap_or(0); } } return Some(result); } if op == "Transpose" { let inp = producer.input.first()?; let in_shape = resolve_body_tensor_shape_inner(inp, body, outer_graph, known_shapes, depth + 1)?; let perm = &producer.attribute.iter().find(|a| a.name == "perm")?.ints; let result: Vec = perm .iter() .filter_map(|&p| in_shape.get(p as usize).copied()) .collect(); if result.len() == in_shape.len() { return Some(result); } } None } ================================================ FILE: crates/dsperse/src/utils/io.rs ================================================ use std::collections::HashMap; use std::path::Path; use ndarray::{ArrayD, Axis, IxDyn}; use rmpv::Value; use crate::error::{DsperseError, Result}; pub fn read_msgpack(path: &Path) -> Result { let data = crate::utils::limits::read_checked(path)?; rmp_serde::from_slice(&data).map_err(Into::into) } pub fn write_msgpack(path: &Path, value: &Value) -> Result<()> { if let Some(parent) = path.parent() { 
std::fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?; } let data = rmp_serde::to_vec_named(value)?; std::fs::write(path, data).map_err(|e| DsperseError::io(e, path)) } pub fn extract_input_data(value: &Value) -> Option<&Value> { map_get_ref(value, "input_data") .or_else(|| map_get_ref(value, "input")) .or_else(|| map_get_ref(value, "data")) .or_else(|| map_get_ref(value, "inputs")) } pub fn flatten_nested_list(value: &Value) -> Vec { let mut result = Vec::new(); flatten_recursive(value, &mut result); result } fn flatten_recursive(value: &Value, out: &mut Vec) { match value { Value::F64(f) => out.push(*f), Value::F32(f) => out.push(*f as f64), Value::Integer(n) => { if let Some(f) = n.as_f64() { out.push(f); } else { tracing::warn!(number = ?n, "flatten_recursive: dropping non-f64 representable integer"); } } Value::Array(arr) => { for item in arr { flatten_recursive(item, out); } } other => { tracing::warn!(variant = %other, "flatten_recursive: dropping non-numeric value during flattening"); } } } pub fn infer_shape(value: &Value) -> Vec { let mut shape = Vec::new(); let mut current = value; while let Value::Array(arr) = current { shape.push(arr.len()); if let Some(first) = arr.first() { current = first; } else { break; } } shape } pub fn value_to_arrayd(value: &Value) -> Result> { let flat = flatten_nested_list(value); let shape = infer_shape(value); if flat.is_empty() { return ArrayD::from_shape_vec(IxDyn(&shape), vec![]) .map_err(|e| DsperseError::Pipeline(format!("empty arrayd: {e}"))); } if shape.is_empty() && flat.len() == 1 { return ArrayD::from_shape_vec(IxDyn(&[]), flat) .map_err(|e| DsperseError::Pipeline(format!("scalar arrayd: {e}"))); } let product: usize = shape.iter().product(); if product != flat.len() || shape.is_empty() { tracing::warn!( flat_len = flat.len(), ?shape, product, "shape mismatch, falling back to 1D" ); return ArrayD::from_shape_vec(IxDyn(&[flat.len()]), flat) .map_err(|e| DsperseError::Pipeline(format!("arrayd 
reshape fallback: {e}"))); } ArrayD::from_shape_vec(IxDyn(&shape), flat) .map_err(|e| DsperseError::Pipeline(format!("arrayd reshape: {e}"))) } pub fn arrayd_to_value(arr: &ArrayD) -> Value { match arr.ndim() { 0 => Value::F64(arr[IxDyn(&[])]), 1 => { let vals: Vec = arr.iter().map(|&v| Value::F64(v)).collect(); Value::Array(vals) } _ => { let vals: Vec = (0..arr.shape()[0]) .map(|i| { let sub = arr.index_axis(Axis(0), i).to_owned(); arrayd_to_value(&sub) }) .collect(); Value::Array(vals) } } } pub fn gather_inputs_from_cache( cache: &HashMap>, inputs: &[String], ) -> Result> { let mut collected = Vec::new(); let mut missing = Vec::new(); for name in inputs { if let Some(val) = cache.get(name) { collected.push(val.clone()); } else { missing.push(name.clone()); } } if collected.is_empty() { return Err(DsperseError::Pipeline(format!( "no cached tensor found for inputs: {inputs:?}" ))); } if !missing.is_empty() { return Err(DsperseError::Pipeline(format!( "missing tensors in cache: {missing:?} (found {} of {})", collected.len(), inputs.len() ))); } if collected.len() == 1 { return Ok(collected.into_iter().next().unwrap()); } if collected[0].ndim() == 0 { return Err(DsperseError::Pipeline( "cannot concatenate 0-dimensional tensors".into(), )); } let ref_trailing = collected[0].shape()[1..].to_vec(); let ref_product: usize = ref_trailing.iter().product(); let batch = collected[0].shape()[0]; for (i, arr) in collected.iter_mut().enumerate().skip(1) { let trailing = &arr.shape()[1..]; if trailing != ref_trailing.as_slice() { let product: usize = trailing.iter().product(); if product == ref_product && arr.shape()[0] == batch { let orig_shape: Vec = arr.shape().to_vec(); let mut target = vec![batch]; target.extend_from_slice(&ref_trailing); let owned = std::mem::replace(arr, ArrayD::zeros(ndarray::IxDyn(&[]))); *arr = owned .into_shape_with_order(ndarray::IxDyn(&target)) .map_err(|e| { DsperseError::Pipeline(format!( "gather reshape input {i} from {orig_shape:?} to 
{target:?}: {e}", )) })?; } else { return Err(DsperseError::Pipeline(format!( "shape mismatch at input {}: expected trailing dims {:?}, got {:?}", i, ref_trailing, trailing ))); } } } ndarray::concatenate( ndarray::Axis(0), &collected.iter().map(|a| a.view()).collect::>(), ) .map_err(|e| DsperseError::Pipeline(format!("concat inputs: {e}"))) } pub fn build_msgpack_map(entries: Vec<(&str, Value)>) -> Value { Value::Map( entries .into_iter() .map(|(k, v)| (Value::String(k.into()), v)) .collect(), ) } pub fn map_get_ref<'a>(value: &'a Value, key: &str) -> Option<&'a Value> { match value { Value::Map(entries) => entries.iter().find_map(|(k, v)| { if k.as_str().is_some_and(|s| s == key) { Some(v) } else { None } }), _ => None, } } ================================================ FILE: crates/dsperse/src/utils/limits.rs ================================================ use std::io::Read; use std::path::Path; use crate::error::{DsperseError, Result}; pub fn reject_symlink(path: &Path) -> Result<()> { let m = std::fs::symlink_metadata(path).map_err(|e| DsperseError::io(e, path))?; if m.is_symlink() { return Err(DsperseError::Archive(format!( "symlink not permitted: {}", path.file_name() .and_then(|n| n.to_str()) .unwrap_or("") ))); } Ok(()) } fn open_nofollow(path: &Path) -> Result { #[cfg(unix)] { use std::os::unix::fs::OpenOptionsExt; std::fs::OpenOptions::new() .read(true) .custom_flags(libc::O_NOFOLLOW) .open(path) .map_err(|e| { if e.raw_os_error() == Some(libc::ELOOP) { DsperseError::Archive(format!( "symlink not permitted: {}", path.file_name() .and_then(|n| n.to_str()) .unwrap_or("") )) } else { DsperseError::io(e, path) } }) } #[cfg(not(unix))] { reject_symlink(path)?; std::fs::File::open(path).map_err(|e| DsperseError::io(e, path)) } } pub fn read_checked(path: &Path) -> Result> { let mut file = open_nofollow(path)?; let mut buf = Vec::new(); file.read_to_end(&mut buf) .map_err(|e| DsperseError::io(e, path))?; Ok(buf) } pub fn read_to_string_checked(path: &Path) 
-> Result { let mut file = open_nofollow(path)?; let mut buf = String::new(); file.read_to_string(&mut buf) .map_err(|e| DsperseError::io(e, path))?; Ok(buf) } #[cfg(test)] mod tests { use super::*; #[test] fn reject_symlink_on_regular_file() { let tmp = tempfile::NamedTempFile::new().unwrap(); assert!(reject_symlink(tmp.path()).is_ok()); } #[cfg(unix)] #[test] fn reject_symlink_on_symlink() { let dir = tempfile::tempdir().unwrap(); let target = dir.path().join("target"); std::fs::write(&target, b"data").unwrap(); let link = dir.path().join("link"); std::os::unix::fs::symlink(&target, &link).unwrap(); assert!(reject_symlink(&link).is_err()); } #[test] fn read_checked_normal() { let tmp = tempfile::NamedTempFile::new().unwrap(); std::fs::write(tmp.path(), b"hello").unwrap(); let data = read_checked(tmp.path()).unwrap(); assert_eq!(data, b"hello"); } #[test] fn read_to_string_checked_normal() { let tmp = tempfile::NamedTempFile::new().unwrap(); std::fs::write(tmp.path(), "hello world").unwrap(); let s = read_to_string_checked(tmp.path()).unwrap(); assert_eq!(s, "hello world"); } } ================================================ FILE: crates/dsperse/src/utils/metadata.rs ================================================ use std::path::Path; use crate::error::{DsperseError, Result}; use crate::schema::RunMetadata; pub fn load_run_metadata(path: &Path) -> Result { let data = crate::utils::limits::read_checked(path)?; rmp_serde::from_slice(&data).map_err(Into::into) } pub fn save_run_metadata(path: &Path, meta: &RunMetadata) -> Result<()> { if let Some(parent) = path.parent() { std::fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?; } let data = rmp_serde::to_vec_named(meta)?; std::fs::write(path, data).map_err(|e| DsperseError::io(e, path)) } ================================================ FILE: crates/dsperse/src/utils/mod.rs ================================================ pub mod io; pub mod limits; pub mod metadata; pub mod paths; 
================================================
FILE: crates/dsperse/src/utils/paths.rs
================================================
use std::path::{Component, Path, PathBuf};

use crate::error::{DsperseError, Result};

// Fixed file names used by the pipeline's on-disk layout.
pub const METADATA_FILE: &str = "metadata.msgpack";
pub const INPUT_FILE: &str = "input.msgpack";
pub const OUTPUT_FILE: &str = "output.msgpack";
pub const WITNESS_FILE: &str = "witness.bin";
pub const PROOF_FILE: &str = "proof.bin";

/// Joins a metadata-supplied `relative` path onto `base`.
///
/// Returns `DsperseError::Archive` for absolute paths and for any `..`,
/// root, or drive-prefix component, so the result stays lexically under
/// `base`. `.` components and the empty string are accepted. The path is
/// not canonicalized, so symlinks are not resolved here.
pub fn resolve_relative_path(base: &Path, relative: &str) -> Result {
    let rel = Path::new(relative);
    if rel.is_absolute() {
        return Err(DsperseError::Archive(format!(
            "absolute path in metadata is not permitted: {relative}"
        )));
    }
    for component in rel.components() {
        match component {
            Component::ParentDir => {
                return Err(DsperseError::Archive(format!(
                    "path traversal component in metadata is not permitted: {relative}"
                )));
            }
            // Also catches Windows-style prefixes that `is_absolute` misses
            // (e.g. a drive letter without a root).
            Component::RootDir | Component::Prefix(_) => {
                return Err(DsperseError::Archive(format!(
                    "invalid path component in metadata: {relative}"
                )));
            }
            _ => {}
        }
    }
    Ok(base.join(rel))
}

/// Returns `path` relative to `base` as a lossy UTF-8 string, or the whole
/// `path` (lossy) when it does not start with `base`.
pub fn relativize_path(path: &Path, base: &Path) -> String {
    path.strip_prefix(base)
        .map(|p| p.to_string_lossy().to_string())
        .unwrap_or_else(|_| path.to_string_lossy().to_string())
}

/// Directory for slice `index`: `<root>/slice_<index>`.
pub fn slice_dir_path(root: &Path, index: usize) -> PathBuf {
    root.join(format!("slice_{index}"))
}

/// Looks for `metadata.msgpack` directly in `dir`, then in `dir/slices`;
/// returns the first one that exists.
pub fn find_metadata_path(dir: &Path) -> Option {
    let direct = dir.join(METADATA_FILE);
    if direct.exists() {
        return Some(direct);
    }
    let slices = dir.join("slices").join(METADATA_FILE);
    if slices.exists() {
        return Some(slices);
    }
    None
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_relative_normal_path() {
        let base = Path::new("/tmp/slices");
        let result = resolve_relative_path(base, "payload/model.onnx").unwrap();
        assert_eq!(result, PathBuf::from("/tmp/slices/payload/model.onnx"));
    }

    #[test]
    fn resolve_relative_rejects_absolute() {
        let base = Path::new("/tmp/slices");
        assert!(resolve_relative_path(base, "/etc/passwd").is_err());
    }

    #[test]
    fn resolve_relative_rejects_parent_dir() {
        let base = Path::new("/tmp/slices");
        assert!(resolve_relative_path(base, "../../../etc/passwd").is_err());
    }

    // A `..` after a normal component is rejected too, not only a leading one.
    #[test]
    fn resolve_relative_rejects_embedded_parent() {
        let base = Path::new("/tmp/slices");
        assert!(resolve_relative_path(base, "payload/../../../etc/passwd").is_err());
    }

    #[test]
    fn resolve_relative_allows_current_dir() {
        let base = Path::new("/tmp/slices");
        let result = resolve_relative_path(base, "./model.onnx").unwrap();
        assert_eq!(result, PathBuf::from("/tmp/slices/./model.onnx"));
    }

    #[test]
    fn resolve_relative_empty_string() {
        let base = Path::new("/tmp/slices");
        let result = resolve_relative_path(base, "").unwrap();
        assert_eq!(result, PathBuf::from("/tmp/slices/"));
    }
}

================================================
FILE: crates/dsperse/src/version.rs
================================================
use serde::{Deserialize, Serialize};

/// Version stamp for dsperse and the linked jstprove circuits crate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DsperseVersion {
    pub dsperse_version: String,
    pub dsperse_rev: Option,
    pub jstprove_version: String,
    pub jstprove_rev: Option,
}

/// Builds the version stamp. The dsperse version and revision are taken at
/// compile time (`CARGO_PKG_VERSION`, optional `DSPERSE_GIT_REV`); the
/// jstprove values come from `jstprove_circuits::api::jstprove_artifact_version`.
pub fn dsperse_artifact_version() -> DsperseVersion {
    let jst_ver = jstprove_circuits::api::jstprove_artifact_version();
    DsperseVersion {
        dsperse_version: env!("CARGO_PKG_VERSION").to_string(),
        dsperse_rev: option_env!("DSPERSE_GIT_REV").map(String::from),
        jstprove_version: jst_ver.crate_version,
        jstprove_rev: Some(jst_ver.git_rev),
    }
}

================================================
FILE: crates/dsperse/tests/integration_slice.rs
================================================
use std::path::Path;

use dsperse::schema::metadata::ModelMetadata;

/// `tests/models` two directories above this crate's manifest directory.
fn test_models_dir() -> &'static Path {
    Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/../../tests/models"))
}

// Slices the `net` model and checks the written metadata and copied model.
#[test]
fn slice_net_model() {
    let model_path = test_models_dir().join("net/model.onnx");
    assert!(
        model_path.exists(),
        "test model not found at {}",
        model_path.display()
    );

    let tmp = tempfile::tempdir().expect("create temp dir");
    let output_dir = tmp.path().join("slices");

    let metadata = dsperse::slicer::slice_model(
        &model_path,
        Some(&output_dir),
        None,
        jstprove_circuits::ProofSystem::Expander.supported_ops(),
        None,
    )
    .expect("slice_model");

    assert!(!metadata.slices.is_empty());
    assert_eq!(metadata.model_type, "ONNX");
    assert!(!metadata.input_shape.is_empty());
    assert!(!metadata.output_shapes.is_empty());

    let meta_path = output_dir.join("metadata.msgpack");
    assert!(meta_path.exists(), "metadata.msgpack must be written");

    let model_onnx = output_dir.join("model.onnx");
    assert!(
        model_onnx.exists(),
        "model.onnx must be copied to output dir"
    );

    let loaded = ModelMetadata::load(&meta_path).expect("load metadata");
    assert_eq!(loaded.slices.len(), metadata.slices.len());
    assert!(loaded.traced_shapes.is_some());
    assert!(loaded.original_model_path.is_some());
    assert_eq!(loaded.original_model_path.as_deref(), Some("model.onnx"));
}

// Slices the `doom` model; every slice must be indexed in order and have
// non-empty input and output dependencies.
#[test]
fn slice_doom_model() {
    let model_path = test_models_dir().join("doom/model.onnx");
    assert!(
        model_path.exists(),
        "test model not found at {}",
        model_path.display()
    );

    let tmp = tempfile::tempdir().expect("create temp dir");
    let output_dir = tmp.path().join("slices");

    let metadata = dsperse::slicer::slice_model(
        &model_path,
        Some(&output_dir),
        None,
        jstprove_circuits::ProofSystem::Expander.supported_ops(),
        None,
    )
    .expect("slice_model");

    assert!(!metadata.slices.is_empty());
    for (i, slice) in metadata.slices.iter().enumerate() {
        assert_eq!(slice.index, i);
        assert!(!slice.dependencies.input.is_empty());
        assert!(!slice.dependencies.output.is_empty());
    }
}

// Same `net` model, sliced against the Remainder op set.
#[test]
fn slice_net_model_remainder() {
    let model_path = test_models_dir().join("net/model.onnx");
    assert!(
        model_path.exists(),
        "test model not found at {}",
        model_path.display()
    );

    let tmp = tempfile::tempdir().expect("create temp dir");
    let output_dir = tmp.path().join("slices");

    let metadata = dsperse::slicer::slice_model(
        &model_path,
        Some(&output_dir),
        None,
        jstprove_circuits::ProofSystem::Remainder.supported_ops(),
        None,
    )
    .expect("slice_model with Remainder");

    assert!(!metadata.slices.is_empty());
    assert_eq!(metadata.model_type, "ONNX");
}

// Passes `Some(8)` as the third `slice_model` argument (tile size, per the
// test name) and checks metadata is still written.
#[test]
fn slice_with_tile_size() {
    let model_path = test_models_dir().join("net/model.onnx");
    assert!(
        model_path.exists(),
        "test model not found at {}",
        model_path.display()
    );

    let tmp = tempfile::tempdir().expect("create temp dir");
    let output_dir = tmp.path().join("slices");

    let metadata = dsperse::slicer::slice_model(
        &model_path,
        Some(&output_dir),
        Some(8),
        jstprove_circuits::ProofSystem::Expander.supported_ops(),
        None,
    )
    .expect("slice_model");

    assert!(!metadata.slices.is_empty());
    let meta_path = output_dir.join("metadata.msgpack");
    assert!(meta_path.exists());
}

// Metadata returned in memory must match what is reloaded from disk.
#[test]
fn slice_metadata_roundtrip_from_disk() {
    let model_path = test_models_dir().join("net/model.onnx");
    assert!(
        model_path.exists(),
        "test model not found at {}",
        model_path.display()
    );

    let tmp = tempfile::tempdir().expect("create temp dir");
    let output_dir = tmp.path().join("slices");

    let original = dsperse::slicer::slice_model(
        &model_path,
        Some(&output_dir),
        None,
        jstprove_circuits::ProofSystem::Expander.supported_ops(),
        None,
    )
    .expect("slice_model");

    let meta_path = output_dir.join("metadata.msgpack");
    let deserialized = ModelMetadata::load(&meta_path).expect("load metadata");

    assert_eq!(original.slices.len(), deserialized.slices.len());
    assert_eq!(original.original_model, deserialized.original_model);
    assert_eq!(original.input_shape, deserialized.input_shape);
    assert_eq!(original.output_shapes, deserialized.output_shapes);
    assert_eq!(original.traced_shapes, deserialized.traced_shapes);
}

// After materialization each slice has `slice_<i>/payload/<filename>`.
#[test]
fn materialize_from_manifest() {
    let model_path = test_models_dir().join("net/model.onnx");
    assert!(
        model_path.exists(),
        "test model not found at {}",
        model_path.display()
    );

    let tmp = tempfile::tempdir().expect("create temp dir");
    let output_dir = tmp.path().join("slices");

    let metadata = dsperse::slicer::slice_model(
        &model_path,
        Some(&output_dir),
        None,
        jstprove_circuits::ProofSystem::Expander.supported_ops(),
        None,
    )
    .expect("slice_model");

    dsperse::slicer::materializer::ensure_all_slices_materialized(&output_dir, &metadata)
        .expect("materialize all slices");

    for slice in &metadata.slices {
        let slice_dir = output_dir.join(format!("slice_{}", slice.index));
        assert!(
            slice_dir.exists(),
            "slice dir must exist after materialization: {}",
            slice_dir.display()
        );
        let payload_dir = slice_dir.join("payload");
        assert!(payload_dir.exists(), "payload dir must exist");
        let onnx_file = payload_dir.join(&slice.filename);
        assert!(
            onnx_file.exists(),
            "onnx file must exist: {}",
            onnx_file.display()
        );
    }
}

// `resolve_onnx` must yield a regular file under `output_dir` whose string
// form contains `output_dir` exactly once (guards against double-joining).
#[test]
fn resolve_onnx_points_to_existing_file_after_materialize() {
    let model_path = test_models_dir().join("net/model.onnx");
    assert!(
        model_path.exists(),
        "test model not found at {}",
        model_path.display()
    );

    let tmp = tempfile::tempdir().expect("create temp dir");
    let output_dir = tmp.path().join("slices");

    let metadata = dsperse::slicer::slice_model(
        &model_path,
        Some(&output_dir),
        None,
        jstprove_circuits::ProofSystem::Expander.supported_ops(),
        None,
    )
    .expect("slice_model");

    dsperse::slicer::materializer::ensure_all_slices_materialized(&output_dir, &metadata)
        .expect("materialize all slices");

    let loaded = ModelMetadata::load(&output_dir.join("metadata.msgpack")).expect("load metadata");
    assert!(!loaded.slices.is_empty());

    for slice in &loaded.slices {
        let resolved = slice.resolve_onnx(&output_dir).unwrap();
        assert!(
            resolved.is_file(),
            "resolve_onnx for slice {} must point to a regular file, got: {}",
            slice.index,
            resolved.display()
        );
        assert!(
            resolved.starts_with(&output_dir),
            "resolved path must start with output_dir"
        );
        let resolved_path = resolved.to_string_lossy();
        let output_str = output_dir.to_string_lossy();
        let count = resolved_path.matches(output_str.as_ref()).count();
        assert_eq!(
            count, 1,
            "output_dir must appear exactly once in resolved path, got {count}"
        );
    }
}

================================================
FILE: crates/dsperse/tests/schema_roundtrip.rs
================================================
use std::path::Path;

use dsperse::schema::*;

// JSON -> ModelMetadata -> MessagePack -> ModelMetadata. The two-element
// `halo` in the JSON is asserted to read back as four elements.
#[test]
fn model_metadata_roundtrip() {
    let json = r#"{ "original_model": "model.onnx", "model_type": "onnx", "input_shape": [[1, 3, 32, 32]], "output_shapes": [[1, 10]], "slice_points": [2, 5], "slices": [ { "index": 0, "filename": "slice_0.onnx", "path": "/tmp/slices/slice_0/payload/slice_0.onnx", "relative_path": "slice_0/payload/slice_0.onnx", "shape": { "tensor_shape": { "input": [[1, 3, 32, 32]], "output": [[1, 16, 16, 16]] } }, "dependencies": { "input": ["input"], "output": ["conv1_out"], "filtered_inputs": ["input"] }, "compilation": { "jstprove": { "compiled": true, "tiled": false, "weights_as_inputs": false, "files": { "compiled": "jstprove/circuit.txt", "settings": "jstprove/settings.json" } } } }, { "index": 1, "filename": "slice_1.onnx", "path": "/tmp/slices/slice_1/payload/slice_1.onnx", "relative_path": "slice_1/payload/slice_1.onnx", "shape": { "tensor_shape": { "input": [[1, 16, 16, 16]], "output": [[1, 10]] } }, "dependencies": { "input": ["conv1_out"], "output": ["output"], "filtered_inputs": ["conv1_out"] }, "tiling": { "slice_idx": 1, "tile_size": 8, "num_tiles": 4, "tiles_y": 2, "tiles_x": 2, "halo": [1, 1], "out_tile": [8, 8], "stride": [1, 1], "c_in": 16, "c_out": 32, "input_name": "conv1_out", "output_name": "conv2_out", "tile": { "path": "tiles/tile.onnx", "conv_out": [8, 8] }, "tiles": [ {"path": "tiles/tile.onnx", "conv_out": [8, 8]}, {"path": "tiles/tile.onnx", "conv_out": [8, 8]} ] }, "compilation": { "jstprove": { "compiled": false, "tiled": false, "weights_as_inputs": false, "files": {} } } } ] }"#;

    let meta: ModelMetadata = serde_json::from_str(json).unwrap();
    assert_eq!(meta.original_model, "model.onnx");
    assert_eq!(meta.slices.len(), 2);
    assert_eq!(meta.slice_points, vec![2, 5]);

    let s0 = &meta.slices[0];
    assert_eq!(s0.index, 0);
    assert!(s0.compilation.jstprove.compiled);
    assert_eq!(
        s0.compilation.jstprove.files.compiled.as_deref(),
        Some("jstprove/circuit.txt")
    );
    assert!(s0.tiling.is_none());

    let s1 = &meta.slices[1];
    assert!(s1.tiling.is_some());
    let tiling = s1.tiling.as_ref().unwrap();
    assert_eq!(tiling.num_tiles, 4);
    assert_eq!(tiling.halo, [1, 1, 1, 1]);
    assert_eq!(tiling.tiles.as_ref().unwrap().len(), 2);

    let msgpack_bytes = rmp_serde::to_vec_named(&meta).unwrap();
    let meta2: ModelMetadata = rmp_serde::from_slice(&msgpack_bytes).unwrap();
    assert_eq!(meta2.slices.len(), 2);
    assert_eq!(meta2.slices[0].index, 0);
}

// JSON -> RunMetadata; the JSON key `circuit_path` is asserted to populate
// `jstprove_circuit_path`, then the value is round-tripped via MessagePack.
#[test]
fn run_metadata_roundtrip() {
    let json = r#"{ "slices": { "slice_0": { "path": "slice_0/payload/slice_0.onnx", "input_shape": [[1, 3, 32, 32]], "output_shape": [[1, 16, 16, 16]], "dependencies": { "input": ["input"], "output": ["conv1_out"], "filtered_inputs": ["input"] }, "backend": "jstprove", "circuit_path": "slice_0/payload/jstprove/circuit.txt" } }, "execution_chain": { "head": "slice_0", "nodes": { "slice_0": { "slice_id": "slice_0", "primary": "slice_0/payload/jstprove/circuit.txt", "fallbacks": ["slice_0/payload/slice_0.onnx"], "use_circuit": true, "next": null, "circuit_path": "slice_0/payload/jstprove/circuit.txt", "onnx_path": "slice_0/payload/slice_0.onnx", "backend": "jstprove" } }, "fallback_map": {}, "execution_results": [], "jstprove_proved_slices": 0, "jstprove_verified_slices": 0 } }"#;

    let meta: RunMetadata = serde_json::from_str(json).unwrap();
    assert_eq!(meta.slices.len(), 1);
    let slice = meta.get_slice("slice_0").unwrap();
    assert_eq!(slice.backend, BackendKind::Jstprove);
    assert_eq!(
        slice.jstprove_circuit_path.as_deref(),
        Some("slice_0/payload/jstprove/circuit.txt")
    );

    let chain = &meta.execution_chain;
    assert_eq!(chain.head.as_deref(), Some("slice_0"));
    assert!(chain.nodes["slice_0"].use_circuit);

    let circuit_slices: Vec<_> = meta.iter_circuit_slices().collect();
    assert_eq!(circuit_slices.len(), 1);
    assert_eq!(circuit_slices[0].0, "slice_0");

    let msgpack_bytes = rmp_serde::to_vec_named(&meta).unwrap();
    let meta2: RunMetadata = rmp_serde::from_slice(&msgpack_bytes).unwrap();
    assert_eq!(meta2.slices.len(), 1);
}

// Per-tile results, including a failed tile carrying an error string.
#[test]
fn execution_info_with_tiles() {
    let json = r#"{ "method": "tiled", "success": true, "tile_exec_infos": [ {"tile_idx": 0, "success": true, "method": "jstprove_gen_witness", "time_sec": 1.5}, {"tile_idx": 1, "success": true, "method": "jstprove_gen_witness", "time_sec": 1.3}, {"tile_idx": 2, "success": false, "error": "timeout", "time_sec": 30.0} ] }"#;

    let info: ExecutionInfo = serde_json::from_str(json).unwrap();
    assert!(info.success);
    assert_eq!(info.tile_exec_infos.len(), 3);
    assert!(!info.tile_exec_infos[2].success);
    assert_eq!(info.tile_exec_infos[2].error.as_deref(), Some("timeout"));
}

#[test]
fn channel_split_roundtrip() {
    let json = r#"{ "slice_idx": 2, "c_in": 64, "c_out": 128, "num_groups": 4, "channels_per_group": 16, "input_name": "relu1_out", "output_name": "conv2_out", "h": 16, "w": 16, "groups": [ {"group_idx": 0, "c_start": 0, "c_end": 16, "path": "channel_groups/group_0.onnx"}, {"group_idx": 1, "c_start": 16, "c_end": 32, "path": "channel_groups/group_1.onnx"} ], "bias_path": "channel_groups/bias.msgpack" }"#;

    let info: ChannelSplitInfo = serde_json::from_str(json).unwrap();
    assert_eq!(info.num_groups, 4);
    assert_eq!(info.groups.len(), 2);
    assert_eq!(info.groups[0].c_end, 16);
    assert_eq!(
        info.bias_path.as_deref(),
        Some("channel_groups/bias.msgpack")
    );

    let msgpack_bytes = rmp_serde::to_vec_named(&info).unwrap();
    let info2: ChannelSplitInfo = rmp_serde::from_slice(&msgpack_bytes).unwrap();
    assert_eq!(info2.num_groups, 4);
}

// `compiled`, `compiled_circuit` and `circuit` must all populate `compiled`.
#[test]
fn compilation_files_aliases() {
    let json1 = r#"{"compiled": "circuit.txt"}"#;
    let json2 = r#"{"compiled_circuit": "circuit.txt"}"#;
    let json3 = r#"{"circuit": "circuit.txt"}"#;

    let f1: CompilationFiles = serde_json::from_str(json1).unwrap();
    let f2: CompilationFiles = serde_json::from_str(json2).unwrap();
    let f3: CompilationFiles = serde_json::from_str(json3).unwrap();

    assert_eq!(f1.compiled.as_deref(), Some("circuit.txt"));
    assert_eq!(f2.compiled.as_deref(), Some("circuit.txt"));
    assert_eq!(f3.compiled.as_deref(), Some("circuit.txt"));
}

// Serialization is lowercase; deserialization also accepts uppercase.
#[test]
fn backend_serde() {
    assert_eq!(
        serde_json::to_string(&BackendKind::Jstprove).unwrap(),
        r#""jstprove""#
    );
    assert_eq!(
        serde_json::to_string(&BackendKind::Onnx).unwrap(),
        r#""onnx""#
    );

    let b: BackendKind = serde_json::from_str(r#""jstprove""#).unwrap();
    assert_eq!(b, BackendKind::Jstprove);
    let b: BackendKind = serde_json::from_str(r#""JSTPROVE""#).unwrap();
    assert_eq!(b, BackendKind::Jstprove);
}

#[test]
fn tensor_shape_i64_deserialization() {
    let json = r#"{ "input": [[1, 3, 224, 224]], "output": [[1, 1000]] }"#;
    let shape: TensorShape = serde_json::from_str(json).unwrap();
    assert_eq!(shape.input, vec![vec![1i64, 3, 224, 224]]);
    assert_eq!(shape.output, vec![vec![1i64, 1000]]);

    let msgpack_bytes = rmp_serde::to_vec_named(&shape).unwrap();
    let shape2: TensorShape = rmp_serde::from_slice(&msgpack_bytes).unwrap();
    assert_eq!(shape2.input, shape.input);
    assert_eq!(shape2.output, shape.output);
}

// A string inside a dimension list must fail deserialization.
#[test]
fn tensor_shape_rejects_non_integer() {
    let json = r#"{"input": [[1, "hello", 3]], "output": []}"#;
    let result: std::result::Result = serde_json::from_str(json);
    assert!(result.is_err());
}

#[test]
fn run_slice_metadata_i64_shapes() {
    let json = r#"{ "path": "slice_0/payload/slice_0.onnx", "input_shape": [[1, 3, 32, 32]], "output_shape": [[1, 16, 16, 16]], "dependencies": { "input": ["input"], "output": ["conv1_out"], "filtered_inputs": ["input"] }, "backend": "onnx" }"#;
    let meta: RunSliceMetadata = serde_json::from_str(json).unwrap();
    assert_eq!(meta.input_shape, vec![vec![1i64, 3, 32, 32]]);
    assert_eq!(meta.output_shape, vec![vec![1i64, 16, 16, 16]]);

    let msgpack_bytes = rmp_serde::to_vec_named(&meta).unwrap();
    let meta2: RunSliceMetadata = rmp_serde::from_slice(&msgpack_bytes).unwrap();
    assert_eq!(meta2.input_shape, meta.input_shape);
    assert_eq!(meta2.output_shape, meta.output_shape);
}

// `resolve_onnx` must join `relative_path` onto the given slices dir and
// ignore the stale absolute `path` field.
#[test]
fn resolve_onnx_uses_relative_path_not_absolute() {
    let json = r#"{ "index": 0, "filename": "slice_0.onnx", "path": "/original/cwd/slices/slice_0/payload/slice_0.onnx", "relative_path": "slice_0/payload/slice_0.onnx", "shape": {"tensor_shape": {"input": [[1, 3, 32, 32]], "output": [[1, 10]]}}, "dependencies": {"input": [], "output": [], "filtered_inputs": []}, "compilation": {"jstprove": {"compiled": false, "tiled": false, "weights_as_inputs": false, "files": {}}} }"#;
    let slice: SliceMetadata = serde_json::from_str(json).unwrap();

    let slices_dir = Path::new("/relocated/slices");
    let resolved = slice.resolve_onnx(slices_dir).unwrap();
    assert_eq!(
        resolved,
        Path::new("/relocated/slices/slice_0/payload/slice_0.onnx"),
        "resolve_onnx must use relative_path (relative to slices_dir), not the absolute path field"
    );
    assert!(
        !resolved.to_string_lossy().contains("/original/"),
        "resolved path must not contain the original CWD-relative path"
    );
}

================================================
FILE: crates/dsperse/tests/sn2_contract.rs
================================================
use std::path::Path;

use ndarray::{ArrayD, IxDyn};
use rmpv::Value;

// Helpers that build nested `Value::Array`s of `Value::F64` (1-D to 4-D).
fn make_value_array(vals: &[f64]) -> Value {
    Value::Array(vals.iter().map(|&v| Value::F64(v)).collect())
}

fn make_value_2d(rows: &[&[f64]]) -> Value {
    Value::Array(rows.iter().map(|row| make_value_array(row)).collect())
}

fn make_value_3d(planes: &[&[&[f64]]]) -> Value {
    Value::Array(planes.iter().map(|plane| make_value_2d(plane)).collect())
}

fn make_value_4d(blocks: &[&[&[&[f64]]]]) -> Value {
    Value::Array(blocks.iter().map(|block| make_value_3d(block)).collect())
}

// Value -> ArrayD -> Value must reproduce the input exactly for ranks 1-4.
#[test]
fn value_arrayd_roundtrip_1d() {
    let input = make_value_array(&[1.0, 2.0, 3.0, 4.0]);
    let arr = dsperse::utils::io::value_to_arrayd(&input).unwrap();
    assert_eq!(arr.shape(), &[4]);
    assert_eq!(arr[IxDyn(&[0])], 1.0);
    assert_eq!(arr[IxDyn(&[3])], 4.0);
    let output = dsperse::utils::io::arrayd_to_value(&arr);
    assert_eq!(output, input);
}

#[test]
fn value_arrayd_roundtrip_2d() {
    let input = make_value_2d(&[&[1.0, 2.0], &[3.0, 4.0]]);
    let arr = dsperse::utils::io::value_to_arrayd(&input).unwrap();
    assert_eq!(arr.shape(), &[2, 2]);
    assert_eq!(arr[IxDyn(&[0, 0])], 1.0);
    assert_eq!(arr[IxDyn(&[1, 1])], 4.0);
    let output = dsperse::utils::io::arrayd_to_value(&arr);
    assert_eq!(output, input);
}

#[test]
fn value_arrayd_roundtrip_3d() {
    let input = make_value_3d(&[&[&[1.0, 2.0], &[3.0, 4.0]], &[&[5.0, 6.0], &[7.0, 8.0]]]);
    let arr = dsperse::utils::io::value_to_arrayd(&input).unwrap();
    assert_eq!(arr.shape(), &[2, 2, 2]);
    assert_eq!(arr[IxDyn(&[0, 0, 0])], 1.0);
    assert_eq!(arr[IxDyn(&[1, 1, 1])], 8.0);
    let output = dsperse::utils::io::arrayd_to_value(&arr);
    assert_eq!(output, input);
}

#[test]
fn value_arrayd_roundtrip_4d() {
    let input = make_value_4d(&[&[&[&[0.5, 1.5], &[2.5, 3.5]], &[&[4.5, 5.5], &[6.5, 7.5]]]]);
    let arr = dsperse::utils::io::value_to_arrayd(&input).unwrap();
    assert_eq!(arr.shape(), &[1, 2, 2, 2]);
    assert_eq!(arr[IxDyn(&[0, 0, 0, 0])], 0.5);
    assert_eq!(arr[IxDyn(&[0, 1, 1, 1])], 7.5);
    let output = dsperse::utils::io::arrayd_to_value(&arr);
    assert_eq!(output, input);
}

#[test]
fn value_arrayd_full_roundtrip_preserves_values() {
    let original = make_value_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
    let arr = dsperse::utils::io::value_to_arrayd(&original).unwrap();
    let reconstructed = dsperse::utils::io::arrayd_to_value(&arr);
    let arr2 = dsperse::utils::io::value_to_arrayd(&reconstructed).unwrap();
    assert_eq!(arr.shape(), arr2.shape());
    assert_eq!(arr, arr2);
    assert_eq!(original, reconstructed);
}

// Key lookup order: `input_data`, then `input`, `data`, `inputs`.
#[test]
fn extract_input_data_key_precedence() {
    let val = Value::Map(vec![
        (Value::String("input_data".into()), make_value_array(&[1.0])),
        (Value::String("input".into()), make_value_array(&[2.0])),
        (Value::String("data".into()), make_value_array(&[3.0])),
        (Value::String("inputs".into()), make_value_array(&[4.0])),
    ]);
    let extracted = dsperse::utils::io::extract_input_data(&val).unwrap();
    assert_eq!(extracted, &make_value_array(&[1.0]));
}

#[test]
fn extract_input_data_fallback_to_input() {
    let val = Value::Map(vec![
        (Value::String("input".into()), make_value_array(&[2.0])),
        (Value::String("data".into()), make_value_array(&[3.0])),
        (Value::String("inputs".into()), make_value_array(&[4.0])),
    ]);
    let extracted = dsperse::utils::io::extract_input_data(&val).unwrap();
    assert_eq!(extracted, &make_value_array(&[2.0]));
}

#[test]
fn extract_input_data_fallback_to_data() {
    let val = Value::Map(vec![
        (Value::String("data".into()), make_value_array(&[3.0])),
        (Value::String("inputs".into()), make_value_array(&[4.0])),
    ]);
    let extracted = dsperse::utils::io::extract_input_data(&val).unwrap();
    assert_eq!(extracted, &make_value_array(&[3.0]));
}

#[test]
fn extract_input_data_fallback_to_inputs() {
    let val = Value::Map(vec![(
        Value::String("inputs".into()),
        make_value_array(&[4.0]),
    )]);
    let extracted = dsperse::utils::io::extract_input_data(&val).unwrap();
    assert_eq!(extracted, &make_value_array(&[4.0]));
}

#[test]
fn extract_input_data_returns_none_for_unrecognized_keys() {
    let val = Value::Map(vec![
        (Value::String("tensor".into()), make_value_array(&[1.0])),
        (Value::String("x".into()), make_value_array(&[2.0])),
    ]);
    assert!(dsperse::utils::io::extract_input_data(&val).is_none());
}

#[test]
fn slice_dir_path_formats_correctly() {
    let root = Path::new("/some/root");
    assert_eq!(
        dsperse::utils::paths::slice_dir_path(root, 0),
        Path::new("/some/root/slice_0")
    );
    assert_eq!(
        dsperse::utils::paths::slice_dir_path(root, 5),
        Path::new("/some/root/slice_5")
    );
    assert_eq!(
        dsperse::utils::paths::slice_dir_path(root, 42),
        Path::new("/some/root/slice_42")
    );
}

// ArrayD -> Value wrapped under `input_data` -> extracted -> ArrayD again.
#[test]
fn arrayd_to_value_then_extract_input_data_integration() {
    let arr = ArrayD::from_shape_vec(IxDyn(&[1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
    let tensor_val = dsperse::utils::io::arrayd_to_value(&arr);
    let wrapped = Value::Map(vec![(Value::String("input_data".into()), tensor_val)]);
    let extracted = dsperse::utils::io::extract_input_data(&wrapped).unwrap();
    let roundtripped = dsperse::utils::io::value_to_arrayd(extracted).unwrap();
    assert_eq!(arr.shape(), roundtripped.shape());
    assert_eq!(arr, roundtripped);
}

================================================
FILE: deny.toml
================================================
[graph]
targets = []
all-features = false

[advisories]
yanked = "warn"

[bans]
multiple-versions = "warn"
wildcards = "warn"

[sources]
unknown-registry = "deny"
unknown-git = "warn"
# The JSTprove repository is the only explicitly allowed git source.
allow-git = [
    "https://github.com/inference-labs-inc/JSTprove.git",
]

================================================
FILE: docs/JSTPROVE_BACKEND.md
================================================
# JSTprove Backend Integration

## Overview

DSperse uses [JSTprove](https://github.com/inference-labs-inc/JSTprove) as its ZK proving backend. JSTprove is integrated as a Rust library dependency (`jstprove_circuits` crate) linked at compile time — there is no external CLI or Python process involved.

DSperse is proving-system-agnostic. JSTprove currently provides two proof system backends selectable via the `--proof-system` flag:

| Proof System | Description |
|--------------|-------------|
| `expander` (default) | Expander-based proving system |
| `remainder` | Remainder-based proving system |

## Architecture

The integration lives in `crates/dsperse/src/backend/jstprove.rs`, which wraps the `jstprove_circuits` crate.
The `JstproveBackend` struct exposes the following operations, each of which maps directly to a pipeline stage: | Pipeline Stage | JSTprove Function | Description | |----------------|-------------------|-------------| | Compile | `compile_bn254` | Compiles an ONNX slice into a BN254 circuit (msgpack bundle) | | Witness | `witness_bn254` / `witness_bn254_from_f64` | Generates a witness from JSON or raw f64 inputs | | Prove | `prove_bn254` | Generates a proof from a compiled circuit and witness | | Verify | `verify_bn254` | Verifies a proof against a circuit and witness | | Extract | `extract_outputs_bn254` | Extracts model outputs from a witness | Circuit compilation produces a msgpack bundle containing the circuit, the witness solver, and optional metadata (`CircuitParams`). All subsequent operations load this bundle via `read_circuit_msgpack`. ## Proof Pipeline Flow ```text ONNX slice | v compile_bn254 --> compiled circuit bundle (.msgpack) | v witness_bn254 --> witness bytes | v prove_bn254 --> proof bytes | v verify_bn254 --> bool (valid/invalid) ``` ## Proof System Selection The `--proof-system` flag is available on the `slice`, `compile`, and `full-run` subcommands. Each proof system defines its own set of supported ONNX operations, which can be queried via `ProofSystem::supported_ops()`. The `--circuit-ops` flag restricts compilation to a subset of the supported ops. ```bash dsperse slice --model-dir models/net --proof-system expander dsperse compile --model-dir models/net --proof-system remainder dsperse full-run --model-dir models/net --proof-system expander --circuit-ops "MatMul,Relu" ``` ## Dependency JSTprove is pulled in as a Cargo git dependency via the `jstprove_circuits` crate. No separate installation step is required — it is compiled into the dsperse binary.
================================================ FILE: docs/overview.md ================================================ # DSperse: Distributed zkML ## Overview DSperse is a proving-system-agnostic intelligent slicer for verifiable AI. It decomposes ONNX neural network models into circuit-compatible segments and orchestrates compilation, inference, proving, and verification across pluggable ZK backends. ### Core Purpose The project addresses a significant challenge in zkML (zero-knowledge machine learning) by introducing intelligent model slicing, which enables distributed proof computation across heterogeneous hardware. ### Key Technical Innovation The main innovation is the concept of "model slicing", in which: 1. The neural network is not processed as a single unit. 2. Instead, the system splits it into manageable segments. 3. Each segment can be processed independently for analysis, inference, or proof generation. ### Primary Goals 1. **Model Slicing** - Split neural network models into individual layers or custom segments - Support ONNX models - Enable detailed analysis of model components 2. **Distributed Computation** - Break down large ML models into manageable pieces - Enable parallel processing across multiple machines - Support both GPU and non-GPU nodes 3. **Resource Optimization** - Reduce RAM requirements through model splitting - Implement efficient inference pipelines - Manage compute resources more effectively 4. **System Flexibility** - Support different model types - Provide configurable slicing strategies - Adapt to different hardware capabilities 5.
**Zero-Knowledge Proofs** - Generate proofs for sliced model execution via JSTprove integration - Support the Expander and Remainder backends through a proving-system-agnostic design - Optimize proof generation for distributed environments ### Implementation Framework - Built on top of existing tools: - ONNX for model representation and interoperability - JSTprove (`jstprove_circuits` Rust crate) for zero-knowledge proof generation - Expander and Remainder as the underlying proving systems - Comprehensive CLI for: - Model slicing - Inference - Proof generation - Proof verification - Designed to work with various neural network architectures - Focuses on practical applications of zkML technology ================================================ FILE: docs/uv_packaging.md ================================================ # Developer Guide This document is a guide for developers contributing to the project. ## Build System The project uses [maturin](https://www.maturin.rs/) as its build backend. The native Rust extension is compiled via PyO3 and exposed as `dsperse._native`. There are no Python-level dependencies beyond the compiled extension itself. The relevant build configuration in `pyproject.toml` is: ```toml [build-system] requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" [tool.maturin] features = ["python"] module-name = "dsperse._native" python-source = "python" manifest-path = "crates/dsperse/Cargo.toml" ``` ## Local Development Create a virtual environment, install maturin into it, and build the extension in development mode. The `[build-system]` requirement is only installed automatically for isolated PEP 517 builds, so the `maturin` command must be installed explicitly before running `maturin develop`: ```sh uv venv source .venv/bin/activate uv pip install maturin maturin develop --features python ``` This compiles the Rust crate and installs the resulting native extension into the active virtualenv. Re-run `maturin develop` after any Rust code changes. ## Building a Wheel ```sh maturin build --release --features python ``` The output wheel is self-contained with no additional Python dependencies.
================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" [project] name = "dsperse" version = "0.0.0" description = "Distributed zkML Toolkit" readme = "README.md" requires-python = ">=3.9" license = { file = "LICENSE" } authors = [{ name = "Inference Labs", email = "info@inferencelabs.com" }] [project.scripts] dsperse = "dsperse.cli:main" [tool.maturin] features = ["python"] module-name = "dsperse._native" python-source = "python" manifest-path = "crates/dsperse/Cargo.toml" ================================================ FILE: python/dsperse/__init__.py ================================================ """Public Python API of dsperse: re-exports the functions implemented in the compiled ``dsperse._native`` extension.""" from dsperse._native import ( slice_model, compile_slices, run_inference, prove_run, verify_run, setup_holographic, ) __all__ = [ "slice_model", "compile_slices", "run_inference", "prove_run", "verify_run", "setup_holographic", ] ================================================ FILE: python/dsperse/cli.py ================================================ """Entry point for the ``dsperse`` console script (``dsperse.cli:main``), which delegates to the native ``cli_main``.""" import sys def main(): """Run the native dsperse CLI and return a process exit code. Returns 1, after printing a message to stderr, if the native extension cannot be imported or if ``cli_main`` raises an ``Exception``. A ``SystemExit`` raised by ``cli_main`` is re-raised unchanged, and ``KeyboardInterrupt`` (not an ``Exception`` subclass) propagates. Otherwise returns 0; the return value of ``cli_main`` is not inspected.""" try: from dsperse._native import cli_main except ImportError: print("dsperse native extension not found; install with: pip install dsperse", file=sys.stderr) return 1 try: cli_main() except SystemExit: raise except Exception as e: # noqa: BLE001 - top-level CLI wrapper to convert any error to exit code 1 print(f"error: {e}", file=sys.stderr) return 1 return 0 if __name__ == "__main__": raise SystemExit(main()) ================================================ FILE: rust-toolchain.toml ================================================ [toolchain] channel = "nightly-2026-02-22" components = ["clippy", "rustfmt"]