[
  {
    "path": ".cargo/audit.toml",
    "content": "[advisories]\nignore = [\n    \"RUSTSEC-2026-0009\", # time crate DoS via RFC 2822 parsing — transitive dep, not user-facing\n]\n"
  },
  {
    "path": ".cargo/config.toml",
    "content": "[net]\ngit-fetch-with-cli = true\n"
  },
  {
    "path": ".github/workflows/integration_tests.yml",
    "content": "name: Integration Tests\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\npermissions:\n  contents: read\n\njobs:\n  fmt:\n    name: Rustfmt\n    runs-on: ubuntu-latest\n    timeout-minutes: 10\n    steps:\n      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2\n      - name: Install Rust toolchain\n        run: rustup show\n      - run: cargo fmt --check\n\n  test:\n    name: Rust Tests\n    runs-on: ubuntu-latest\n    timeout-minutes: 45\n    steps:\n      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2\n        with:\n          submodules: true\n\n      - name: Install Rust toolchain\n        run: rustup show\n\n      - name: Install protoc\n        run: sudo apt-get update && sudo apt-get install -y protobuf-compiler\n\n      - uses: Swatinem/rust-cache@e18b497796c12c097a38f9edb9d0641fb99eee32 # v2.9.1\n\n      - name: Test\n        run: cargo test --locked --manifest-path crates/dsperse/Cargo.toml\n\n      - name: Test (with python feature)\n        run: cargo test --locked --manifest-path crates/dsperse/Cargo.toml --features python\n\n      - name: Clippy\n        run: cargo clippy --locked --manifest-path crates/dsperse/Cargo.toml --all-targets --features python -- -D warnings\n\n  audit:\n    name: Security audit\n    runs-on: ubuntu-latest\n    timeout-minutes: 10\n    permissions:\n      contents: read\n      checks: write\n    steps:\n      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2\n      - run: rm -f rust-toolchain.toml && rustup install stable && rustup default stable\n      - uses: rustsec/audit-check@69366f33c96575abad1ee0dba8212993eecbe998 # v2.0.0\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n  deny:\n    name: Cargo deny\n    runs-on: ubuntu-latest\n    timeout-minutes: 10\n    steps:\n      - 
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2\n      - uses: EmbarkStudios/cargo-deny-action@3fd3802e88374d3fe9159b834c7714ec57d6c979 # v2.0.15\n        with:\n          command: check bans sources\n"
  },
  {
    "path": ".github/workflows/publish.yml",
    "content": "name: Build and Publish to PyPI\n\non:\n  push:\n    tags:\n      - \"v*\"\n  pull_request:\n  workflow_dispatch:\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\nenv:\n  UV_VERSION: \"0.10.8\"\n  MATURIN_VERSION: \"1.12.6\"\n\njobs:\n  build-linux:\n    if: >-\n      github.event_name != 'pull_request' ||\n      contains(github.event.pull_request.labels.*.name, 'test-build')\n    runs-on: ubuntu-latest\n    timeout-minutes: 60\n    container:\n      image: quay.io/pypa/manylinux_2_28_x86_64\n    permissions:\n      contents: read\n    steps:\n      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2\n\n      - name: Set up Python\n        run: echo \"/opt/python/cp312-cp312/bin\" >> $GITHUB_PATH\n\n      - name: Install uv\n        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7\n        with:\n          version: ${{ env.UV_VERSION }}\n\n      - name: Install system dependencies\n        run: |\n          dnf install -y protobuf-compiler protobuf-devel pkgconf-pkg-config perl-IPC-Cmd perl-Time-Piece clang-devel\n\n      - name: Install Rust\n        run: |\n          curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly-2025-03-27\n          echo \"$HOME/.cargo/bin\" >> $GITHUB_PATH\n\n      - name: Extract version\n        id: get_version\n        shell: bash\n        run: |\n          if [[ \"$GITHUB_REF\" == refs/tags/v* ]]; then\n            VERSION=${GITHUB_REF#refs/tags/v}\n          else\n            VERSION=$(grep -m1 '^version' pyproject.toml | sed 's/.*\"\\(.*\\)\".*/\\1/')\n          fi\n          echo \"version=$VERSION\" >> $GITHUB_OUTPUT\n\n      - name: Update versions\n        run: |\n          sed -i '0,/^version = \".*\"/{s/^version = \".*\"/version = \"${{ steps.get_version.outputs.version }}\"/}' pyproject.toml\n          sed -i '0,/^version = \".*\"/{s/^version = \".*\"/version = \"${{ 
steps.get_version.outputs.version }}\"/}' crates/dsperse/Cargo.toml\n\n      - name: Cache Rust dependencies\n        uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n            target\n          key: manylinux-2-28-cargo-${{ hashFiles('**/Cargo.lock') }}\n          restore-keys: |\n            manylinux-2-28-cargo-\n\n      - name: Build wheel\n        run: uvx maturin==${{ env.MATURIN_VERSION }} build --release --manylinux 2_28 -i /opt/python/cp312-cp312/bin/python3\n\n      - name: Test wheel installation\n        run: |\n          uv pip install --system --python python3 target/wheels/*.whl\n          python3 -c \"from dsperse import slice_model; print('PyO3 bindings OK')\"\n\n      - name: Upload wheel artifact\n        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0\n        with:\n          name: wheel-ubuntu-x86_64\n          path: ./target/wheels/*.whl\n\n  build-macos:\n    if: >-\n      github.event_name != 'pull_request' ||\n      contains(github.event.pull_request.labels.*.name, 'test-build')\n    runs-on: macos-latest\n    timeout-minutes: 60\n    permissions:\n      contents: read\n    steps:\n      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2\n\n      - name: Set up Python\n        uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0\n        with:\n          python-version: \"3.12\"\n\n      - name: Install uv\n        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7\n        with:\n          version: ${{ env.UV_VERSION }}\n\n      - name: Install system dependencies\n        run: brew install protobuf llvm\n\n      - name: Extract version\n        id: get_version\n        shell: bash\n        run: |\n          if [[ \"$GITHUB_REF\" == refs/tags/v* ]]; then\n            VERSION=${GITHUB_REF#refs/tags/v}\n          else\n           
 VERSION=$(grep -m1 '^version' pyproject.toml | sed 's/.*\"\\(.*\\)\".*/\\1/')\n          fi\n          echo \"version=$VERSION\" >> $GITHUB_OUTPUT\n\n      - name: Update versions\n        run: |\n          sed -i '' '1,/^version = /{s/^version = \".*\"/version = \"${{ steps.get_version.outputs.version }}\"/;}' pyproject.toml\n          sed -i '' '1,/^version = /{s/^version = \".*\"/version = \"${{ steps.get_version.outputs.version }}\"/;}' crates/dsperse/Cargo.toml\n\n      - name: Install Rust\n        uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable 2026-02-13\n        with:\n          toolchain: nightly-2025-03-27\n\n      - name: Install Rust target\n        run: rustup target add aarch64-apple-darwin\n\n      - name: Cache Rust dependencies\n        uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4\n        with:\n          path: |\n            ~/.cargo/registry\n            ~/.cargo/git\n            target\n          key: macos-cargo-${{ hashFiles('**/Cargo.lock') }}\n          restore-keys: |\n            macos-cargo-\n\n      - name: Build wheel\n        run: uvx maturin==${{ env.MATURIN_VERSION }} build --release --target aarch64-apple-darwin\n        env:\n          MACOSX_DEPLOYMENT_TARGET: \"11.0\"\n\n      - name: Test wheel installation\n        run: |\n          uv pip install --system --python python3 target/wheels/*.whl\n          python3 -c \"from dsperse import slice_model; print('PyO3 bindings OK')\"\n\n      - name: Upload wheel artifact\n        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0\n        with:\n          name: wheel-macos-aarch64\n          path: ./target/wheels/*.whl\n\n  publish:\n    if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')\n    needs: [build-linux, build-macos]\n    runs-on: ubuntu-latest\n    timeout-minutes: 15\n    permissions:\n      contents: write\n      id-token: write\n    steps:\n      - name: Download 
all wheels\n        uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1\n        with:\n          pattern: wheel-*\n          merge-multiple: true\n          path: ./dist\n\n      - name: Extract version from tag\n        id: get_version\n        shell: bash\n        run: |\n          VERSION=${GITHUB_REF#refs/tags/v}\n          echo \"version=$VERSION\" >> $GITHUB_OUTPUT\n\n      - name: Create GitHub Release with wheels\n        uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2.6.1\n        with:\n          name: Release ${{ steps.get_version.outputs.version }}\n          files: ./dist/*.whl\n\n      - name: Install uv\n        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7\n        with:\n          version: ${{ env.UV_VERSION }}\n\n      - name: Publish to PyPI\n        run: uv publish ./dist/*\n"
  },
  {
    "path": ".gitignore",
    "content": "# macOS system files\n.DS_Store\n.DS_*\ntests/models/run\n# macOS metadata\n._*\n\n# Python cache\n__pycache__/\n*.py[cod]\n\n# Environment files\n.env\n.venv/\nenv/\nvenv/\n\n# IDE/editor folders\n.vscode/\n.idea/\n\n# Log files\n*.log\n\n# Byte-compiled\n*.pyo\n\n# Jupyter Notebook checkpoints\n.ipynb_checkpoints/\n\n# Python egg artifacts\n*.egg\n*.egg-info/\ndist/\nbuild/\neggs/\nparts/\nbin/\nvar/\nsdist/\ndevelop-eggs/\n.installed.cfg\n\n# ignore the models we test with\n*/models/*/slices\n*/src/models/*/slices/\n*/models/*/model_metadata.json\n*/src/models/*/model_metadata.json\n*/models/*/analysis/model_metadata.json\n*/src/models/*/analysis/model_metadata.json\n*/models/*/run\n*/src/models/*/run/\n*/models/*/input.json\n*/src/models/*/input.json\n*/models/*/*.onnx\n*/src/models/*/*.onnx\n*/models/*/*.dsperse\n*/src/models/*/*.dsperse\n*/models/*/*.data\n*/src/models/*/*.data\n\n\n# Local virtual envs\npython.venv/\n.venv/\nvenv/\n\n# Slice output directories\npitch-sliced/\n*-sliced/\n\n# Test output\ntests/models/output/\n/target\n/crates/*/target\n"
  },
  {
    "path": "Cargo.toml",
    "content": "[workspace]\nmembers = [\"crates/dsperse\"]\nresolver = \"2\"\n\n[workspace.package]\nedition = \"2024\"\n\n[workspace.dependencies]\nserde = { version = \"1\", features = [\"derive\"] }\nrmpv = { version = \"1\", features = [\"with-serde\"] }\nrmp-serde = \"1\"\nthiserror = \"2\"\nclap = { version = \"4\", features = [\"derive\", \"env\"] }\ntracing = \"0.1\"\ntracing-subscriber = { version = \"0.3\", features = [\"env-filter\"] }\nrayon = \"1\"\nndarray = { version = \"0.17\", features = [\"serde\"] }\ntract-onnx = { git = \"https://github.com/inference-labs-inc/tract.git\", rev = \"3cfae7f7\" }\nuuid = { version = \"1\", features = [\"v4\"] }\nsha2 = \"0.10\"\ntempfile = \"3\"\nprost = \"0.13\"\npyo3 = { version = \"0.24\" }\njstprove_circuits = { git = \"https://github.com/inference-labs-inc/JSTprove.git\", rev = \"87a1859f3487cf0fb9a463dbfd713b1df4827afc\" }\njstprove_io = { git = \"https://github.com/inference-labs-inc/JSTprove.git\", rev = \"87a1859f3487cf0fb9a463dbfd713b1df4827afc\", package = \"jstprove-io\" }\nreqwest = { version = \"0.12\", default-features = false, features = [\"rustls-tls\", \"json\"] }\ntokio = { version = \"1\", features = [\"rt\", \"macros\"] }\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2025 Inference Labs Inc.\n\nSource Access Grant\nYou may access, view, study, and modify the source code of this software.\n\nRedistribution Conditions\nYou may redistribute this software in source or modified form provided that:\na) You retain this license document and all copyright notices\nb) Any modified files carry prominent notices stating you changed them\nc) You do not misrepresent the origin of the software\n\nUsage Restriction\nNO USE RIGHTS ARE GRANTED BY THIS LICENSE. Any operational use including but not limited to:\n- Execution of the software\n- Integration with other systems\n- Deployment in any environment\n- Commercial or production utilization requires express written permission from the IP Owner.\n\nIntellectual Property Reservation\nAll rights not expressly granted herein are reserved by the IP Owner. For usage permissions, contact: legal@inferencelabs.com\n\nDisclaimer\nTHIS SOFTWARE IS PROVIDED \"AS IS\" WITHOUT WARRANTY OF ANY KIND. THE IP OWNER SHALL NOT BE LIABLE FOR ANY DAMAGES ARISING FROM ACCESS OR DISTRIBUTION.\n\nLicense Propagation\nAny distribution of this software or derivatives must be under this same license agreement."
  },
  {
    "path": "README.md",
    "content": "# DSperse: Community Edition\n\n[![GitHub](https://img.shields.io/badge/GitHub-Repository-blue?style=flat-square&logo=github)](https://github.com/inference-labs-inc/dsperse)\n[![Discord](https://img.shields.io/badge/Discord-Join%20Community-7289DA?style=flat-square&logo=discord)](https://discord.gg/GBxBCWJs)\n[![Telegram](https://img.shields.io/badge/Telegram-Join%20Channel-0088cc?style=flat-square&logo=telegram)](https://t.me/inference_labs)\n[![Twitter](https://img.shields.io/badge/Twitter-Follow%20Us-1DA1F2?style=flat-square&logo=twitter)](https://x.com/inference_labs)\n[![Website](https://img.shields.io/badge/Website-Visit%20Us-ff7139?style=flat-square&logo=firefox-browser)](https://inferencelabs.com)\n[![Whitepaper](https://img.shields.io/badge/Whitepaper-Read-lightgrey?style=flat-square&logo=read-the-docs)](http://arxiv.org/abs/2508.06972)\n\nDSperse is a proving-system-agnostic intelligent slicer for verifiable AI. It decomposes ONNX neural network models into circuit-compatible segments and orchestrates compilation, inference, proving, and verification across pluggable ZK backends.\n\n## Features\n\n- **Model Slicing**: Split neural network models into individual layers or custom segments\n- **ONNX Support**: Slice and orchestrate ONNX models\n- **Layered Inference**: Run inference on sliced models, chaining the output of each segment\n- **Zero-Knowledge Proofs**: Generate and verify proofs for model execution via JSTprove\n- **Tiling and Channel Splitting**: Automatically decompose large convolutions for circuit-compatible execution\n- **Proof System Agnostic**: Pluggable backend architecture supporting Expander and Remainder proof systems\n\n## Documentation\n\n- [Overview](docs/overview.md): High-level overview of the project, its goals, and features\n- [JSTprove Backend](docs/JSTPROVE_BACKEND.md): JSTprove integration and usage\n\n## Installation\n\n### From PyPI (includes CLI)\n\n```bash\npip install dsperse\n```\n\nThis installs both 
the `dsperse` CLI command and the Python library bindings. No additional dependencies required — everything is compiled into a single native extension.\n\n### From source (Rust binary)\n\n```bash\ncargo install --path crates/dsperse\n```\n\n### As a Rust library\n\n```toml\n[dependencies]\ndsperse = { git = \"https://github.com/inference-labs-inc/dsperse.git\" }\n```\n\n## CLI Usage\n\nDSperse provides six subcommands that form a complete pipeline:\n\n| Command | Description |\n|---------|-------------|\n| `slice` | Split an ONNX model into segments |\n| `compile` | Compile slices into ZK circuits |\n| `run` | Execute chained inference across slices (`--weights` to inject consumer ONNX) |\n| `prove` | Generate ZK proofs for a completed run |\n| `verify` | Verify ZK proofs |\n| `full-run` | Execute compile, run, prove, verify in sequence (supports `--weights`) |\n\n### Quickstart\n\n```bash\ndsperse slice --model-dir models/net\ndsperse compile --model-dir models/net --parallel 4\ndsperse run --model-dir models/net --input-file models/net/input.json\ndsperse prove --model-dir models/net --run-dir models/net/run/run_*\ndsperse verify --model-dir models/net --run-dir models/net/run/run_*\n```\n\nOr run the entire pipeline at once:\n\n```bash\ndsperse full-run --model-dir models/net --input-file models/net/input.json\n```\n\nTo inject consumer weights from a fine-tuned ONNX model (same architecture, different weights):\n\n```bash\ndsperse run --model-dir models/net --input-file models/net/input.json --weights path/to/consumer.onnx\ndsperse full-run --model-dir models/net --input-file models/net/input.json --weights path/to/consumer.onnx\n```\n\n## Python Library Usage\n\n```python\nimport dsperse\n\nmetadata_json = dsperse.slice_model(\"models/net/model.onnx\", output_dir=\"models/net/slices\")\ndsperse.compile_slices(\"models/net/slices\", parallel=4)\nrun_json = dsperse.run_inference(\"models/net/slices\", \"models/net/input.json\", \"models/net/run\")\nproof_json = 
dsperse.prove_run(\"models/net/run\", \"models/net/slices\")\nverify_json = dsperse.verify_run(\"models/net/run\", \"models/net/slices\")\n```\n\nTo inject consumer weights at inference time, pass `weights_onnx` (path to a fine-tuned ONNX with the same architecture):\n\n```python\nrun_json = dsperse.run_inference(\n    \"models/net/slices\", \"models/net/input.json\", \"models/net/run\",\n    weights_onnx=\"path/to/consumer.onnx\",\n)\n```\n\n`slice_model`, `run_inference`, `prove_run`, and `verify_run` return JSON strings parseable with `json.loads()`. `compile_slices` returns `None`.\n\n## Project Structure\n\n```text\ncrates/dsperse/\n  src/\n    cli/          CLI argument parsing and command dispatch\n    slicer/       ONNX model analysis, slicing, autotiling, channel splitting\n    pipeline/     Compilation, inference, proving, verification orchestration\n    backend/      JSTprove backend integration\n    schema/       Metadata and execution result types (serde)\n    converter.rs  Prepares JSTprove artifacts from ONNX files\n    utils/        I/O helpers and path resolution\n  tests/          Unit and integration tests\npython/           Thin Python wrapper for PyO3 bindings\n```\n\n## Contributing\n\nContributions are welcome. Please open issues and PRs on GitHub.\n\n## License\n\nSee the [LICENSE](LICENSE) file for details.\n"
  },
  {
    "path": "crates/dsperse/Cargo.toml",
    "content": "[package]\nname = \"dsperse\"\nversion = \"0.0.0\"\nedition.workspace = true\n\n[features]\ndefault = []\npython = [\"dep:pyo3\", \"pyo3/extension-module\"]\n\n[dependencies]\nserde.workspace = true\nrmpv.workspace = true\nrmp-serde.workspace = true\nthiserror.workspace = true\nclap.workspace = true\ntracing.workspace = true\ntracing-subscriber.workspace = true\nrayon.workspace = true\nndarray.workspace = true\ntract-onnx.workspace = true\nuuid.workspace = true\nsha2.workspace = true\ntempfile.workspace = true\nprost.workspace = true\npyo3 = { workspace = true, optional = true }\nserde_json = \"1\"\nzip = { version = \"2\", default-features = false, features = [\"deflate\"] }\nwalkdir = \"2\"\njstprove_circuits.workspace = true\njstprove_io.workspace = true\nreqwest.workspace = true\ntokio.workspace = true\n\n[target.'cfg(unix)'.dependencies]\nlibc = \"0.2\"\n\n[build-dependencies]\nprost-build = \"0.13\"\n\n[dev-dependencies]\ncriterion = { version = \"0.5\", features = [\"html_reports\"] }\n\n[[bench]]\nname = \"serialization\"\nharness = false\n\n[lib]\nname = \"dsperse\"\ncrate-type = [\"cdylib\", \"lib\"]\n"
  },
  {
    "path": "crates/dsperse/benches/serialization.rs",
    "content": "use std::collections::HashMap;\n\nuse criterion::{Criterion, black_box, criterion_group, criterion_main};\nuse dsperse::schema::execution::{\n    ExecutionChain, ExecutionInfo, ExecutionMethod, ExecutionNode, ExecutionResultEntry,\n    RunMetadata, SliceResult, TileResult,\n};\nuse dsperse::schema::metadata::{\n    BackendKind, Compilation, Dependencies, ModelMetadata, RunSliceMetadata, SliceMetadata,\n    SliceShapeWrapper, TensorShape,\n};\nuse serde::{Deserialize, Serialize};\n\nfn make_slice_metadata(index: usize) -> SliceMetadata {\n    SliceMetadata {\n        index,\n        filename: format!(\"slice_{index}.onnx\"),\n        path: format!(\"/tmp/slices/slice_{index}/payload/slice_{index}.onnx\"),\n        relative_path: format!(\"slice_{index}/payload/slice_{index}.onnx\"),\n        shape: SliceShapeWrapper {\n            tensor_shape: TensorShape {\n                input: vec![vec![1, 3, 224, 224]],\n                output: vec![vec![1, 64, 112, 112]],\n            },\n        },\n        dependencies: Dependencies {\n            input: vec![format!(\"input_{index}\")],\n            output: vec![format!(\"output_{index}\")],\n            filtered_inputs: vec![format!(\"input_{index}\")],\n        },\n        tiling: None,\n        channel_split: None,\n        dim_split: None,\n        compilation: Compilation::default(),\n        slice_metadata: Some(format!(\"slice_{index}/metadata.msgpack\")),\n        slice_metadata_relative_path: Some(format!(\"slice_{index}/metadata.msgpack\")),\n    }\n}\n\nfn make_model_metadata(num_slices: usize) -> ModelMetadata {\n    let slices: Vec<SliceMetadata> = (0..num_slices).map(make_slice_metadata).collect();\n    let slice_points: Vec<usize> = (0..=num_slices).collect();\n    ModelMetadata {\n        original_model: \"/tmp/model.onnx\".into(),\n        model_type: \"ONNX\".into(),\n        input_shape: vec![vec![1, 3, 224, 224]],\n        output_shapes: vec![vec![1, 1000]],\n        output_names: 
vec![\"output\".into()],\n        slice_points,\n        slices,\n        dsperse_version: Some(\"0.0.0\".into()),\n        dsperse_rev: Some(\"abc1234\".into()),\n        jstprove_version: Some(\"0.1.0\".into()),\n        jstprove_rev: Some(\"def5678\".into()),\n        traced_shapes: None,\n        traced_types: None,\n        original_model_path: None,\n        folded_constant_names: vec![],\n    }\n}\n\nfn make_run_metadata(num_slices: usize) -> RunMetadata {\n    let mut slices = HashMap::new();\n    let mut nodes = HashMap::new();\n    let mut execution_results = Vec::new();\n\n    for i in 0..num_slices {\n        let slice_id = format!(\"slice_{i}\");\n        slices.insert(\n            slice_id.clone(),\n            RunSliceMetadata {\n                path: format!(\"slice_{i}/payload/slice_{i}.onnx\"),\n                input_shape: vec![vec![1, 3, 224, 224]],\n                output_shape: vec![vec![1, 64, 112, 112]],\n                dependencies: Dependencies {\n                    input: vec![format!(\"input_{i}\")],\n                    output: vec![format!(\"output_{i}\")],\n                    filtered_inputs: vec![format!(\"input_{i}\")],\n                },\n                tiling: None,\n                channel_split: None,\n                dim_split: None,\n                backend: BackendKind::Jstprove,\n                jstprove_circuit_path: Some(format!(\"slice_{i}/jstprove/circuit.bin\")),\n                jstprove_settings_path: None,\n            },\n        );\n        nodes.insert(\n            slice_id.clone(),\n            ExecutionNode {\n                slice_id: slice_id.clone(),\n                primary: Some(\"jstprove_gen_witness\".into()),\n                fallbacks: vec![\"onnx_only\".into()],\n                use_circuit: true,\n                next: if i + 1 < num_slices {\n                    Some(format!(\"slice_{}\", i + 1))\n                } else {\n                    None\n                },\n                
circuit_path: Some(format!(\"slice_{i}/jstprove/circuit.bin\")),\n                onnx_path: Some(format!(\"slice_{i}/payload/slice_{i}.onnx\")),\n                backend: BackendKind::Jstprove,\n            },\n        );\n        execution_results.push(ExecutionResultEntry {\n            slice_id: slice_id.clone(),\n            witness_execution: Some(ExecutionInfo {\n                method: ExecutionMethod::JstproveGenWitness,\n                success: true,\n                error: None,\n                witness_file: Some(format!(\"runs/run_0/{slice_id}/witness.bin\")),\n                tile_exec_infos: vec![TileResult {\n                    tile_idx: 0,\n                    success: true,\n                    error: None,\n                    method: Some(ExecutionMethod::JstproveGenWitness),\n                    time_sec: 1.23,\n                    proof_path: None,\n                }],\n            }),\n            proof_execution: Some(SliceResult {\n                slice_id: slice_id.clone(),\n                success: true,\n                method: Some(ExecutionMethod::JstproveProve),\n                error: None,\n                proof_path: Some(format!(\"runs/run_0/{slice_id}/proof.bin\")),\n                time_sec: 45.67,\n                tiles: Vec::new(),\n            }),\n            verification_execution: None,\n        });\n    }\n\n    RunMetadata {\n        slices,\n        execution_chain: ExecutionChain {\n            head: Some(\"slice_0\".into()),\n            nodes,\n            fallback_map: HashMap::new(),\n            execution_results,\n            jstprove_proved_slices: num_slices,\n            jstprove_verified_slices: 0,\n        },\n        packaging_type: Some(\"dsperse\".into()),\n        source_path: Some(\"/tmp/model.onnx\".into()),\n        run_directory: Some(\"/tmp/runs/run_0\".into()),\n        model_path: Some(\"/tmp/model.onnx\".into()),\n    }\n}\n\nfn bench_roundtrip<T: Serialize + for<'de> Deserialize<'de>>(\n    c: 
&mut Criterion,\n    name: &str,\n    value: &T,\n) {\n    let json_bytes = serde_json::to_vec(value).unwrap();\n    let msgpack_bytes = rmp_serde::to_vec_named(value).unwrap();\n\n    let group_name = format!(\n        \"{name} (json={}, msgpack={})\",\n        json_bytes.len(),\n        msgpack_bytes.len()\n    );\n    let mut group = c.benchmark_group(&group_name);\n\n    group.bench_function(\"json_serialize\", |b| {\n        b.iter(|| serde_json::to_vec(black_box(value)).unwrap());\n    });\n    group.bench_function(\"msgpack_serialize\", |b| {\n        b.iter(|| rmp_serde::to_vec_named(black_box(value)).unwrap());\n    });\n    group.bench_function(\"json_deserialize\", |b| {\n        b.iter(|| serde_json::from_slice::<T>(black_box(&json_bytes)).unwrap());\n    });\n    group.bench_function(\"msgpack_deserialize\", |b| {\n        b.iter(|| rmp_serde::from_slice::<T>(black_box(&msgpack_bytes)).unwrap());\n    });\n\n    group.finish();\n}\n\nfn serialization_benchmarks(c: &mut Criterion) {\n    let small_model = make_model_metadata(4);\n    let large_model = make_model_metadata(64);\n    let small_run = make_run_metadata(4);\n    let large_run = make_run_metadata(64);\n\n    bench_roundtrip(c, \"ModelMetadata_4slices\", &small_model);\n    bench_roundtrip(c, \"ModelMetadata_64slices\", &large_model);\n    bench_roundtrip(c, \"RunMetadata_4slices\", &small_run);\n    bench_roundtrip(c, \"RunMetadata_64slices\", &large_run);\n}\n\ncriterion_group!(benches, serialization_benchmarks);\ncriterion_main!(benches);\n"
  },
  {
    "path": "crates/dsperse/build.rs",
    "content": "fn main() {\n    prost_build::Config::new()\n        .compile_protos(&[\"proto/onnx.proto\"], &[\"proto/\"])\n        .expect(\"Failed to compile ONNX proto\");\n\n    let git_rev = std::process::Command::new(\"git\")\n        .args([\"rev-parse\", \"--short\", \"HEAD\"])\n        .output()\n        .ok()\n        .filter(|o| o.status.success())\n        .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string());\n\n    if let Some(ref rev) = git_rev {\n        println!(\"cargo:rustc-env=DSPERSE_GIT_REV={rev}\");\n    }\n\n    let pkg_version = std::env::var(\"CARGO_PKG_VERSION\").unwrap();\n    let display_version = match (pkg_version.as_str(), &git_rev) {\n        (\"0.0.0\", Some(rev)) => format!(\"dev-{rev}\"),\n        (\"0.0.0\", None) => \"dev\".to_string(),\n        (v, Some(rev)) => format!(\"{v}+{rev}\"),\n        (v, None) => v.to_string(),\n    };\n    println!(\"cargo:rustc-env=DSPERSE_DISPLAY_VERSION={display_version}\");\n    if let Some(output) = std::process::Command::new(\"git\")\n        .args([\"rev-parse\", \"--git-path\", \"HEAD\"])\n        .output()\n        .ok()\n        .filter(|o| o.status.success())\n    {\n        let head_path = String::from_utf8_lossy(&output.stdout).trim().to_string();\n        println!(\"cargo:rerun-if-changed={head_path}\");\n    }\n\n    if let Some(output) = std::process::Command::new(\"git\")\n        .args([\"symbolic-ref\", \"-q\", \"HEAD\"])\n        .output()\n        .ok()\n        .filter(|o| o.status.success())\n    {\n        let head_ref = String::from_utf8_lossy(&output.stdout).trim().to_string();\n        if let Some(output) = std::process::Command::new(\"git\")\n            .args([\"rev-parse\", \"--git-path\", &head_ref])\n            .output()\n            .ok()\n            .filter(|o| o.status.success())\n        {\n            let ref_path = String::from_utf8_lossy(&output.stdout).trim().to_string();\n            println!(\"cargo:rerun-if-changed={ref_path}\");\n        }\n 
   }\n}\n"
  },
  {
    "path": "crates/dsperse/proto/onnx.proto",
    "content": "//\n// WARNING: This file is automatically generated!  Please edit onnx.in.proto.\n//\n\n\n// SPDX-License-Identifier: Apache-2.0\n\n\nsyntax = \"proto3\";\n\npackage onnx;\n\n// Overview\n//\n// ONNX is an open specification that is comprised of the following components:\n//\n// 1)  A definition of an extensible computation graph model.\n// 2)  Definitions of standard data types.\n// 3)  Definitions of built-in operators.\n//\n// This document describes the syntax of models and their computation graphs,\n// as well as the standard data types. Together, they are referred to as the ONNX\n// Intermediate Representation, or 'IR' for short.\n//\n// The normative semantic specification of the ONNX IR is found in docs/IR.md.\n// Definitions of the built-in neural network operators may be found in docs/Operators.md.\n\n// Notes\n//\n// Protobuf compatibility\n//\n// To simplify framework compatibility, ONNX is defined using the subset of protobuf\n// that is compatible with both protobuf v2 and v3. This means that we do not use any\n// protobuf features that are only available in one of the two versions.\n//\n// Here are the most notable contortions we have to carry out to work around\n// these limitations:\n//\n//   - No 'map' (added protobuf 3.0). We instead represent mappings as lists\n//     of key-value pairs, where order does not matter and duplicates\n//     are not allowed.\n\n\n// Versioning\n//\n// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md\n//\n// To be compatible with both proto2 and proto3, we will use a version number\n// that is not defined by the default value but an explicit enum number.\nenum Version {\n  // proto3 requires the first enum value to be zero.\n  // We add this just to appease the compiler.\n  _START_VERSION = 0;\n  // The version field is always serialized and we will use it to store the\n  // version that the  graph is generated from. 
This helps us set up version\n  // control.\n  // For the IR, we are using simple numbers starting with 0x00000001,\n  // which was the version we published on Oct 10, 2017.\n  IR_VERSION_2017_10_10 = 0x0000000000000001;\n\n  // IR_VERSION 2 published on Oct 30, 2017\n  // - Added type discriminator to AttributeProto to support proto3 users\n  IR_VERSION_2017_10_30 = 0x0000000000000002;\n\n  // IR VERSION 3 published on Nov 3, 2017\n  // - For operator versioning:\n  //    - Added new message OperatorSetIdProto\n  //    - Added opset_import in ModelProto\n  // - For vendor extensions, added domain in NodeProto\n  IR_VERSION_2017_11_3 = 0x0000000000000003;\n\n  // IR VERSION 4 published on Jan 22, 2019\n  // - Relax constraint that initializers should be a subset of graph inputs\n  // - Add type BFLOAT16\n  IR_VERSION_2019_1_22 = 0x0000000000000004;\n\n  // IR VERSION 5 published on March 18, 2019\n  // - Add message TensorAnnotation.\n  // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.\n  IR_VERSION_2019_3_18 = 0x0000000000000005;\n\n  // IR VERSION 6 published on Sep 19, 2019\n  // - Add support for sparse tensor constants stored in model.\n  //   - Add message SparseTensorProto\n  //   - Add sparse initializers\n  IR_VERSION_2019_9_19 = 0x0000000000000006;\n\n  // IR VERSION 7 published on May 8, 2020\n  // - Add support to allow function body graph to rely on multiple external operator sets.\n  // - Add a list to promote inference graph's initializers to global and\n  //   mutable variables. Global variables are visible in all graphs of the\n  //   stored models.\n  // - Add message TrainingInfoProto to store initialization\n  //   method and training algorithm. 
The execution of TrainingInfoProto\n  //   can modify the values of mutable variables.\n  // - Implicitly add inference graph into each TrainingInfoProto's algorithm.\n  IR_VERSION_2020_5_8 = 0x0000000000000007;\n\n  // IR VERSION 8 published on July 30, 2021\n  // Introduce TypeProto.SparseTensor\n  // Introduce TypeProto.Optional\n  // Added a list of FunctionProtos local to the model\n  // Deprecated since_version and operator status from FunctionProto\n  IR_VERSION_2021_7_30 = 0x0000000000000008;\n\n  // IR VERSION 9 published on May 5, 2023\n  // Added AttributeProto to FunctionProto so that default attribute values can be set.\n  // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.\n  IR_VERSION_2023_5_5 = 0x0000000000000009;\n\n  // IR VERSION 10 published on March 25, 2024\n  // Added UINT4, INT4.\n  IR_VERSION_2024_3_25 = 0x000000000000000A;\n\n  // IR VERSION 11 published on TBD\n  // Added FLOAT4E2M1, multi-device protobuf classes.\n  IR_VERSION = 0x000000000000000B;\n}\n\n// Attributes\n//\n// A named attribute containing either singular float, integer, string, graph,\n// and tensor values, or repeated float, integer, string, graph, and tensor values.\n// An AttributeProto MUST contain the name field, and *only one* of the\n// following content fields, effectively enforcing a C/C++ union equivalent.\nmessage AttributeProto {\n  reserved 12, 16 to 19;\n  reserved \"v\";\n\n  // Note: this enum is structurally identical to the OpSchema::AttrType\n  // enum defined in schema.h.  
If you rev one, you likely need to rev the other.\n  enum AttributeType {\n    UNDEFINED = 0;\n    FLOAT = 1;\n    INT = 2;\n    STRING = 3;\n    TENSOR = 4;\n    GRAPH = 5;\n    SPARSE_TENSOR = 11;\n    TYPE_PROTO = 13;\n\n    FLOATS = 6;\n    INTS = 7;\n    STRINGS = 8;\n    TENSORS = 9;\n    GRAPHS = 10;\n    SPARSE_TENSORS = 12;\n    TYPE_PROTOS = 14;\n  }\n\n  // The name field MUST be present for this version of the IR.\n  string name = 1;           // namespace Attribute\n\n  // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.\n  // In this case, this AttributeProto does not contain data, and it's a reference of attribute\n  // in parent scope.\n  // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.\n  string ref_attr_name = 21;\n\n  // A human-readable documentation for this attribute. Markdown is allowed.\n  string doc_string = 13;\n\n  // The type field MUST be present for this version of the IR.\n  // For 0.0.1 versions of the IR, this field was not defined, and\n  // implementations needed to use has_field heuristics to determine\n  // which value field was in use.  For IR_VERSION 0.0.2 or later, this\n  // field MUST be set and match the f|i|s|t|... field in use.  
This\n  // change was made to accommodate proto3 implementations.\n  AttributeType type = 20;   // discriminator that indicates which field below is in use\n\n  // Exactly ONE of the following fields must be present for this version of the IR\n  float f = 2;               // float\n  int64 i = 3;               // int\n  bytes s = 4;               // UTF-8 string\n  TensorProto t = 5;         // tensor value\n  GraphProto g = 6;          // graph\n  SparseTensorProto sparse_tensor = 22;  // sparse tensor value\n  // Do not use field below, it's deprecated.\n  // optional ValueProto v = 12;         // value - subsumes everything but graph\n  TypeProto tp = 14;          // type proto\n\n  repeated float floats = 7;          // list of floats\n  repeated int64 ints = 8;            // list of ints\n  repeated bytes strings = 9;         // list of UTF-8 strings\n  repeated TensorProto tensors = 10;  // list of tensors\n  repeated GraphProto graphs = 11;    // list of graph\n  repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors\n  repeated TypeProto type_protos = 15;// list of type protos\n}\n\n// Defines information on value, including the name, the type, and\n// the shape of the value.\nmessage ValueInfoProto {\n  // This field MUST be present in this version of the IR.\n  string name = 1;     // namespace Value\n  // This field MUST be present in this version of the IR for\n  // inputs and outputs of the top-level graph.\n  TypeProto type = 2;\n  // A human-readable documentation for this value. 
Markdown is allowed.\n  string doc_string = 3;\n  // Named metadata values; keys should be distinct.\n  repeated StringStringEntryProto metadata_props = 4;\n}\n\n// Nodes\n//\n// Computation graphs are made up of a DAG of nodes, which represent what is\n// commonly called a \"layer\" or \"pipeline stage\" in machine learning frameworks.\n//\n// For example, it can be a node of type \"Conv\" that takes in an image, a filter\n// tensor and a bias tensor, and produces the convolved output.\nmessage NodeProto {\n  repeated string input = 1;    // namespace Value\n  repeated string output = 2;   // namespace Value\n\n  // An optional identifier for this node in a graph.\n  // This field MAY be absent in this version of the IR.\n  string name = 3;     // namespace Node\n\n  // The symbolic identifier of the Operator to execute.\n  string op_type = 4;  // namespace Operator\n  // The domain of the OperatorSet that specifies the operator named by op_type.\n  string domain = 7;   // namespace Domain\n  // Overload identifier, used only to map this to a model-local function.\n  string overload = 8;\n\n  // Additional named attributes.\n  repeated AttributeProto attribute = 5;\n\n  // A human-readable documentation for this node. Markdown is allowed.\n  string doc_string = 6;\n\n  // Named metadata values; keys should be distinct.\n  repeated StringStringEntryProto metadata_props = 9;\n\n  // Configuration of multi-device annotations.\n  repeated NodeDeviceConfigurationProto device_configurations = 10;\n}\n\n// IntIntListEntryProto follows the pattern for cross-proto-version maps.\n// See https://developers.google.com/protocol-buffers/docs/proto3#maps\nmessage IntIntListEntryProto {\n  int64 key = 1;\n  repeated int64 value = 2;\n};\n\n// Multi-device configuration proto for NodeProto.\nmessage NodeDeviceConfigurationProto {\n    // This field MUST be present for this version of the IR.\n    // ID of the configuration. 
MUST match the name of a DeviceConfigurationProto.\n    string configuration_id = 1;\n    // Sharding spec for the node.\n    repeated ShardingSpecProto sharding_spec = 2;\n    // Pipeline stage of this node.\n    int32 pipeline_stage = 3;\n}\n\n// ShardingSpecProto: This describes the sharding spec for a specific\n// input or output tensor of a node.\nmessage ShardingSpecProto {\n  // This field MUST be present for this version of the IR.\n  // Identifies the input or output of the node that is being sharded.\n  // Required to match a name specified in the node's input or output list of ValueInfoProtos.\n  // It is called `logical tensor` in subsequent descriptions.\n  string tensor_name = 1;\n\n  // The following is the list of devices across which the logical\n  // tensor is sharded or replicated.\n  repeated int64 device = 2;\n\n  // Each element v in above field devices may represent either a\n  // device or a set of devices (when we want the same shard/tensor\n  // to be replicated across a subset of devices), as indicated by\n  // the following optional map. If the map contains an entry for v,\n  // then v represents a device group, and the map indicates the set\n  // of devices in that group.\n  repeated IntIntListEntryProto index_to_device_group_map = 3;\n\n  // The following is the sharded-shape of the tensor, consisting of\n  // the sharding-spec for each axis of the tensor.\n  repeated ShardedDimProto sharded_dim = 4;\n}\n\n// ShardedDimProto: This describes the sharding spec for a single\n// axis of a sharded tensor.\nmessage ShardedDimProto {\n  // This field MUST be present for this version of the IR.\n  // The axis this sharding corresponds to. Must be in the range of\n  // [-r, r - 1], where r is the rank of the tensor. 
Negative axis values means\n  // counting from the back.\n  int64 axis = 1;\n\n  // Describes how the tensor on the provided axis is sharded.\n  // The common-case is described by a single instance of SimpleShardedDimProto.\n  // Multiple instances can be used to handle cases where a sharded\n  // tensor is reshaped, fusing multiple axes into one.\n  repeated SimpleShardedDimProto simple_sharding = 2;\n}\n\n// SimpleShardedDimProto: Indicates that N blocks are divided into M shards.\n// N is allowed to be symbolic where M is required to be a constant.\nmessage SimpleShardedDimProto {\n    // Dimension value to be sharded.\n    oneof dim {\n        int64 dim_value = 1;\n        string dim_param = 2;\n    }\n\n    // This field MUST be present for this version of the IR.\n    // Number of shards to split dim into.\n    int64 num_shards = 3;\n}\n\n// Training information\n// TrainingInfoProto stores information for training a model.\n// In particular, this defines two functionalities: an initialization-step\n// and a training-algorithm-step. Initialization resets the model\n// back to its original state as if no training has been performed.\n// Training algorithm improves the model based on input data.\n//\n// The semantics of the initialization-step is that the initializers\n// in ModelProto.graph and in TrainingInfoProto.algorithm are first\n// initialized as specified by the initializers in the graph, and then\n// updated by the \"initialization_binding\" in every instance in\n// ModelProto.training_info.\n//\n// The field \"algorithm\" defines a computation graph which represents a\n// training algorithm's step. After the execution of a\n// TrainingInfoProto.algorithm, the initializers specified by \"update_binding\"\n// may be immediately updated. 
If the targeted training algorithm contains\n// consecutive update steps (such as block coordinate descent methods),\n// the user needs to create a TrainingInfoProto for each step.\nmessage TrainingInfoProto {\n  // This field describes a graph to compute the initial tensors\n  // upon starting the training process. Initialization graph has no input\n  // and can have multiple outputs. Usually, trainable tensors in neural\n  // networks are randomly initialized. To achieve that, for each tensor,\n  // the user can put a random number operator such as RandomNormal or\n  // RandomUniform in TrainingInfoProto.initialization.node and assign its\n  // random output to the specific tensor using \"initialization_binding\".\n  // This graph can also set the initializers in \"algorithm\" in the same\n  // TrainingInfoProto; a use case is resetting the number of training\n  // iteration to zero.\n  //\n  // By default, this field is an empty graph and its evaluation does not\n  // produce any output. Thus, no initializer would be changed by default.\n  GraphProto initialization = 1;\n\n  // This field represents a training algorithm step. Given required inputs,\n  // it computes outputs to update initializers in its own or inference graph's\n  // initializer lists. In general, this field contains loss node, gradient node,\n  // optimizer node, increment of iteration count.\n  //\n  // An execution of the training algorithm step is performed by executing the\n  // graph obtained by combining the inference graph (namely \"ModelProto.graph\")\n  // and the \"algorithm\" graph. That is, the actual\n  // input/initializer/output/node/value_info/sparse_initializer list of\n  // the training graph is the concatenation of\n  // \"ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer\"\n  // and \"algorithm.input/initializer/output/node/value_info/sparse_initializer\"\n  // in that order. 
This combined graph must satisfy the normal ONNX conditions.\n  // Now, let's provide a visualization of graph combination for clarity.\n  // Let the inference graph (i.e., \"ModelProto.graph\") be\n  //    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d\n  // and the \"algorithm\" graph be\n  //    tensor_d -> Add -> tensor_e\n  // The combination process results\n  //    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e\n  //\n  // Notice that an input of a node in the \"algorithm\" graph may reference the\n  // output of a node in the inference graph (but not the other way round). Also, inference\n  // node cannot reference inputs of \"algorithm\". With these restrictions, inference graph\n  // can always be run independently without training information.\n  //\n  // By default, this field is an empty graph and its evaluation does not\n  // produce any output. Evaluating the default training step never\n  // update any initializers.\n  GraphProto algorithm = 2;\n\n  // This field specifies the bindings from the outputs of \"initialization\" to\n  // some initializers in \"ModelProto.graph.initializer\" and\n  // the \"algorithm.initializer\" in the same TrainingInfoProto.\n  // See \"update_binding\" below for details.\n  //\n  // By default, this field is empty and no initializer would be changed\n  // by the execution of \"initialization\".\n  repeated StringStringEntryProto initialization_binding = 3;\n\n  // Gradient-based training is usually an iterative procedure. In one gradient\n  // descent iteration, we apply\n  //\n  // x = x - r * g\n  //\n  // where \"x\" is the optimized tensor, \"r\" stands for learning rate, and \"g\" is\n  // gradient of \"x\" with respect to a chosen loss. To avoid adding assignments\n  // into the training graph, we split the update equation into\n  //\n  // y = x - r * g\n  // x = y\n  //\n  // The user needs to save \"y = x - r * g\" into TrainingInfoProto.algorithm. 
To\n  // tell that \"y\" should be assigned to \"x\", the field \"update_binding\" may\n  // contain a key-value pair of strings, \"x\" (key of StringStringEntryProto)\n  // and \"y\" (value of StringStringEntryProto).\n  // For a neural network with multiple trainable (mutable) tensors, there can\n  // be multiple key-value pairs in \"update_binding\".\n  //\n  // The initializers that appear as keys in \"update_binding\" are considered\n  // mutable variables. This implies some behaviors\n  // as described below.\n  //\n  //  1. We have only unique keys in all \"update_binding\"s so that two\n  //     variables may not have the same name. This ensures that one\n  //     variable is assigned up to once.\n  //  2. The keys must appear in names of \"ModelProto.graph.initializer\" or\n  //     \"TrainingInfoProto.algorithm.initializer\".\n  //  3. The values must be output names of \"algorithm\" or \"ModelProto.graph.output\".\n  //  4. Mutable variables are initialized to the value specified by the\n  //     corresponding initializer, and then potentially updated by\n  //     \"initialization_binding\"s and \"update_binding\"s in \"TrainingInfoProto\"s.\n  //\n  // This field usually contains names of trainable tensors\n  // (in ModelProto.graph), optimizer states such as momentums in advanced\n  // stochastic gradient methods (in TrainingInfoProto.graph),\n  // and number of training iterations (in TrainingInfoProto.graph).\n  //\n  // By default, this field is empty and no initializer would be changed\n  // by the execution of \"algorithm\".\n  repeated StringStringEntryProto update_binding = 4;\n}\n\n// Models\n//\n// ModelProto is a top-level file/container format for bundling a ML model and\n// associating its computation graph with metadata.\n//\n// The semantics of the model are described by the associated GraphProto's.\nmessage ModelProto {\n  // The version of the IR this model targets. 
See Version enum above.\n  // This field MUST be present.\n  int64 ir_version = 1;\n\n  // The OperatorSets this model relies on.\n  // All ModelProtos MUST have at least one entry that\n  // specifies which version of the ONNX OperatorSet is\n  // being imported.\n  //\n  // All nodes in the ModelProto's graph will bind against the operator\n  // with the same-domain/same-op_type operator with the HIGHEST version\n  // in the referenced operator sets.\n  repeated OperatorSetIdProto opset_import = 8;\n\n  // The name of the framework or tool used to generate this model.\n  // This field SHOULD be present to indicate which implementation/tool/framework\n  // emitted the model.\n  string producer_name = 2;\n\n  // The version of the framework or tool used to generate this model.\n  // This field SHOULD be present to indicate which implementation/tool/framework\n  // emitted the model.\n  string producer_version = 3;\n\n  // Domain name of the model.\n  // We use reverse domain names as name space indicators. For example:\n  // `com.facebook.fair` or `com.microsoft.cognitiveservices`\n  //\n  // Together with `model_version` and GraphProto.name, this forms the unique identity of\n  // the graph.\n  string domain = 4;\n\n  // The version of the graph encoded. See Version enum below.\n  int64 model_version = 5;\n\n  // A human-readable documentation for this model. Markdown is allowed.\n  string doc_string = 6;\n\n  // The parameterized graph that is evaluated to execute the model.\n  GraphProto graph = 7;\n\n  // Named metadata values; keys should be distinct.\n  repeated StringStringEntryProto metadata_props = 14;\n\n  // Training-specific information. Sequentially executing all stored\n  // `TrainingInfoProto.algorithm`s and assigning their outputs following\n  // the corresponding `TrainingInfoProto.update_binding`s is one training\n  // iteration. 
Similarly, to initialize the model\n  // (as if training hasn't happened), the user should sequentially execute\n  // all stored `TrainingInfoProto.initialization`s and assign their outputs\n  // using `TrainingInfoProto.initialization_binding`s.\n  //\n  // If this field is empty, the training behavior of the model is undefined.\n  repeated TrainingInfoProto training_info = 20;\n\n  // A list of function protos local to the model.\n  //\n  // The (domain, name, overload) tuple must be unique across the function protos in this list.\n  // In case of any conflicts the behavior (whether the model local functions are given higher priority,\n  // or standard operator sets are given higher priority or this is treated as error) is defined by\n  // the runtimes.\n  //\n  // The operator sets imported by FunctionProto should be compatible with the ones\n  // imported by ModelProto and other model local FunctionProtos.\n  // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto\n  // or by 2 FunctionProtos then versions for the operator set may be different but,\n  // the operator schema returned for op_type, domain, version combination\n  // for both the versions should be same for every node in the function body.\n  //\n  // One FunctionProto can reference other FunctionProto in the model, however, recursive reference\n  // is not allowed.\n  repeated FunctionProto functions = 25;\n\n  // Describes different target configurations for a multi-device use case.\n  // A model MAY describe multiple multi-device configurations for execution.\n  repeated DeviceConfigurationProto configuration = 26;\n};\n\n// DeviceConfigurationProto describes a multi-device configuration for a model.\nmessage DeviceConfigurationProto {\n    // This field MUST be present for this version of the IR.\n    // Name of the configuration.\n    string name = 1;\n    // This field MUST be present for this version of the IR.\n    // Number of devices inside this 
configuration.\n    int32 num_devices = 2;\n    // Optional names of the devices. MUST be length of num_devices if provided.\n    repeated string device = 3;\n}\n\n// StringStringEntryProto follows the pattern for cross-proto-version maps.\n// See https://developers.google.com/protocol-buffers/docs/proto3#maps\nmessage StringStringEntryProto {\n  string key = 1;\n  string value = 2;\n};\n\nmessage TensorAnnotation {\n  string tensor_name = 1;\n  // <key, value> pairs to annotate tensor specified by <tensor_name> above.\n  // The keys used in the mapping below must be pre-defined in ONNX spec.\n  // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as\n  // quantization parameter keys.\n  repeated StringStringEntryProto quant_parameter_tensor_names = 2;\n}\n\n\n\n// Graphs\n//\n// A graph defines the computational logic of a model and is comprised of a parameterized\n// list of nodes that form a directed acyclic graph based on their inputs and outputs.\n// This is the equivalent of the \"network\" or \"graph\" in many deep learning\n// frameworks.\nmessage GraphProto {\n  // The nodes in the graph, sorted topologically.\n  repeated NodeProto node = 1;\n\n  // The name of the graph.\n  string name = 2;   // namespace Graph\n\n  // A list of named tensor values, used to specify constant inputs of the graph.\n  // Each initializer (both TensorProto as well as SparseTensorProto) MUST have a name.\n  // The name MUST be unique across both initializer and sparse_initializer,\n  // but the name MAY also appear in the input list.\n  repeated TensorProto initializer = 5;\n\n  // Initializers (see above) stored in sparse format.\n  repeated SparseTensorProto sparse_initializer = 15;\n\n  // A human-readable documentation for this graph. 
Markdown is allowed.\n  string doc_string = 10;\n\n  // The inputs and outputs of the graph.\n  repeated ValueInfoProto input = 11;\n  repeated ValueInfoProto output = 12;\n\n  // Information for the values in the graph. The ValueInfoProto.name's\n  // must be distinct. It is optional for a value to appear in value_info list.\n  repeated ValueInfoProto value_info = 13;\n\n  // This field carries information to indicate the mapping among a tensor and its\n  // quantization parameter tensors. For example:\n  // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,\n  // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.\n  repeated TensorAnnotation quantization_annotation = 14;\n\n  // Named metadata values; keys should be distinct.\n  repeated StringStringEntryProto metadata_props = 16;\n\n  reserved 3, 4, 6 to 9;\n  reserved \"ir_version\", \"producer_version\", \"producer_tag\", \"domain\";\n}\n\n// Tensors\n//\n// A serialized tensor value.\nmessage TensorProto {\n  enum DataType {\n    UNDEFINED = 0;\n    // Basic types.\n    FLOAT = 1;   // float\n    UINT8 = 2;   // uint8_t\n    INT8 = 3;    // int8_t\n    UINT16 = 4;  // uint16_t\n    INT16 = 5;   // int16_t\n    INT32 = 6;   // int32_t\n    INT64 = 7;   // int64_t\n    STRING = 8;  // string\n    BOOL = 9;    // bool\n\n    // IEEE754 half-precision floating-point format (16 bits wide).\n    // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.\n    FLOAT16 = 10;\n\n    DOUBLE = 11;\n    UINT32 = 12;\n    UINT64 = 13;\n    COMPLEX64 = 14;     // complex with float32 real and imaginary components\n    COMPLEX128 = 15;    // complex with float64 real and imaginary components\n\n    // Non-IEEE floating-point format based on IEEE754 single-precision\n    // floating-point number truncated to 16 bits.\n    // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.\n    
BFLOAT16 = 16;\n\n    // Non-IEEE floating-point format based on papers\n    // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,\n    // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.\n    // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.\n    // The computation usually happens inside a block quantize / dequantize\n    // fused by the runtime.\n    FLOAT8E4M3FN = 17;    // float 8, mostly used for coefficients, supports nan, not inf\n    FLOAT8E4M3FNUZ = 18;  // float 8, mostly used for coefficients, supports nan, not inf, no negative zero\n    FLOAT8E5M2 = 19;      // follows IEEE 754, supports nan, inf, mostly used for gradients\n    FLOAT8E5M2FNUZ = 20;  // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero\n\n    // 4-bit integer data types\n    UINT4 = 21;  // Unsigned integer in range [0, 15]\n    INT4 = 22;   // Signed integer in range [-8, 7], using two's-complement representation\n\n    // 4-bit floating point data types\n    FLOAT4E2M1 = 23;\n\n    // Future extensions go here.\n  }\n\n  // The shape of the tensor.\n  repeated int64 dims = 1;\n\n  // The data type of the tensor.\n  // This field MUST have a valid TensorProto.DataType value\n  int32 data_type = 2;\n\n  // For very large tensors, we may want to store them in chunks, in which\n  // case the following fields will specify the segment that is stored in\n  // the current TensorProto.\n  message Segment {\n    int64 begin = 1;\n    int64 end = 2;\n  }\n  Segment segment = 3;\n\n  // Tensor content must be organized in row-major order.\n  //\n  // Depending on the data_type field, exactly one of the fields below with\n  // name ending in _data is used to store the elements of the tensor.\n\n  // For float and complex64 values\n  // Complex64 tensors are encoded as a single array of floats,\n  // with the real components appearing in odd numbered positions,\n  // and the 
corresponding imaginary component appearing in the\n  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]\n  // is encoded as [1.0, 2.0 ,3.0 ,4.0]\n  // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.\n  repeated float float_data = 4 [packed = true];\n\n  // For int32, uint8, int8, uint16, int16, uint4, int4, bool, (b)float16, float8, and float4:\n  // - (b)float16 and float8 values MUST be converted bit-wise into an unsigned integer\n  //   representation before being written to the buffer.\n  // - Each pair of uint4, int4, and float4 values MUST be packed as two 4-bit elements into a single byte.\n  //   The first element is stored in the 4 least significant bits (LSB),\n  //   and the second element is stored in the 4 most significant bits (MSB).\n  //\n  // Consequently:\n  // - For data types with a bit-width of 8 or greater, each `int32_data` stores one element.\n  // - For 4-bit data types, each `int32_data` stores two elements.\n  //\n  // When this field is present, the data_type field MUST be\n  // INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ, FLOAT4E2M1\n  repeated int32 int32_data = 5 [packed = true];\n\n  // For strings.\n  // Each element of string_data is a UTF-8 encoded Unicode\n  // string. No trailing null, no leading BOM. The protobuf \"string\"\n  // scalar type is not used to match ML community conventions.\n  // When this field is present, the data_type field MUST be STRING\n  repeated bytes string_data = 6;\n\n  // For int64.\n  // When this field is present, the data_type field MUST be INT64\n  repeated int64 int64_data = 7 [packed = true];\n\n  // Optionally, a name for the tensor.\n  string name = 8; // namespace Value\n\n  // A human-readable documentation for this tensor. 
Markdown is allowed.\n  string doc_string = 12;\n\n  // Serializations can either use one of the fields above, or use this\n  // raw bytes field. The only exception is the string case, where one is\n  // required to store the content in the repeated bytes string_data field.\n  //\n  // When this raw_data field is used to store tensor value, elements MUST\n  // be stored in fixed-width, little-endian order.\n  // Floating-point data types MUST be stored in IEEE 754 format.\n  // Complex64 elements must be written as two consecutive FLOAT values, real component first.\n  // Complex128 elements must be written as two consecutive DOUBLE values, real component first.\n  // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).\n  // uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB.\n  //\n  // Note: the advantage of specific field rather than the raw_data field is\n  // that in some cases (e.g. int data), protobuf does a better packing via\n  // variable length storage, and may lead to smaller binary footprint.\n  // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED\n  bytes raw_data = 9;\n\n  // Data can be stored inside the protobuf file using type-specific fields or raw_data.\n  // Alternatively, raw bytes data can be stored in an external file, using the external_data field.\n  // external_data stores key-value pairs describing data location. Recognized keys are:\n  // - \"location\" (required) - POSIX filesystem path relative to the directory where the ONNX\n  //                           protobuf model was stored\n  // - \"offset\" (optional) - position of byte at which stored data begins. Integer stored as string.\n  //                         Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.\n  // - \"length\" (optional) - number of bytes containing data. 
Integer stored as string.\n  // - \"checksum\" (optional) - SHA1 digest of the file specified under the 'location' key.\n  repeated StringStringEntryProto external_data = 13;\n\n  // Location of the data for this tensor. MUST be one of:\n  // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.\n  // - EXTERNAL - data stored in an external location as described by external_data field.\n  enum DataLocation {\n    DEFAULT = 0;\n    EXTERNAL = 1;\n  }\n\n  // If value not set, data is stored in raw_data (if set) otherwise in type-specified field.\n  DataLocation data_location = 14;\n\n  // For double\n  // Complex128 tensors are encoded as a single array of doubles,\n  // with the real components appearing in odd numbered positions,\n  // and the corresponding imaginary component appearing in the\n  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]\n  // is encoded as [1.0, 2.0, 3.0, 4.0])\n  // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128\n  repeated double double_data = 10 [packed = true];\n\n  // For uint64 and uint32 values\n  // When this field is present, the data_type field MUST be\n  // UINT32 or UINT64\n  repeated uint64 uint64_data = 11 [packed = true];\n\n  // Named metadata values; keys should be distinct.\n  repeated StringStringEntryProto metadata_props = 16;\n}\n\n// A serialized sparse-tensor value\nmessage SparseTensorProto {\n  // The sequence of non-default values are encoded as a tensor of shape [NNZ].\n  // The default-value is zero for numeric tensors, and empty-string for string tensors.\n  // values must have a non-empty name present which serves as a name for SparseTensorProto\n  // when used in sparse_initializer list.\n  TensorProto values = 1;\n\n  // The indices of the non-default values, which may be stored in one of two formats.\n  // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value\n  // 
corresponding to the j-th index of the i-th value (in the values tensor).\n  // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value\n  // must be the linearized-index of the i-th value (in the values tensor).\n  // The linearized-index can be converted into an index tuple (k_1,...,k_rank)\n  // using the shape provided below.\n  // The indices must appear in ascending order without duplication.\n  // In the first format, the ordering is lexicographic-ordering:\n  // e.g., index-value [1,4] must appear before [2,1]\n  TensorProto indices = 2;\n\n  // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]\n  repeated int64 dims = 3;\n}\n\n// Defines a tensor shape. A dimension can be either an integer value\n// or a symbolic variable. A symbolic variable represents an unknown\n// dimension.\nmessage TensorShapeProto {\n  message Dimension {\n    oneof value {\n      int64 dim_value = 1;\n      string dim_param = 2;   // namespace Shape\n    };\n    // Standard denotation can optionally be used to denote tensor\n    // dimensions with standard semantic descriptions to ensure\n    // that operations are applied to the correct axis of a tensor.\n    // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition\n    // for pre-defined dimension denotations.\n    string denotation = 3;\n  };\n  repeated Dimension dim = 1;\n}\n\n// Types\n//\n// The standard ONNX data types.\nmessage TypeProto {\n\n  message Tensor {\n    // This field MUST NOT have the value of UNDEFINED\n    // This field MUST have a valid TensorProto.DataType value\n    // This field MUST be present for this version of the IR.\n    int32 elem_type = 1;\n    TensorShapeProto shape = 2;\n  }\n\n  // repeated T\n  message Sequence {\n    // The type and optional shape of each element of the sequence.\n    // This field MUST be present for this version of the IR.\n    TypeProto elem_type = 1;\n  };\n\n  // map<K,V>\n  message Map 
{\n    // This field MUST have a valid TensorProto.DataType value\n    // This field MUST be present for this version of the IR.\n    // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING\n    int32 key_type = 1;\n    // This field MUST be present for this version of the IR.\n    TypeProto value_type = 2;\n  };\n\n  // wrapper for Tensor, Sequence, or Map\n  message Optional {\n    // The type and optional shape of the element wrapped.\n    // This field MUST be present for this version of the IR.\n    // Possible values correspond to OptionalProto.DataType enum\n    TypeProto elem_type = 1;\n  };\n\n\n  message SparseTensor {\n    // This field MUST NOT have the value of UNDEFINED\n    // This field MUST have a valid TensorProto.DataType value\n    // This field MUST be present for this version of the IR.\n    int32 elem_type = 1;\n    TensorShapeProto shape = 2;\n  }\n\n\n  oneof value {\n    // The type of a tensor.\n    Tensor tensor_type = 1;\n\n    // NOTE:  DNN-only implementations of ONNX MAY elect to not support non-tensor values\n    //        as input and output to graphs and nodes. These types are needed to naturally\n    //        support classical ML operators.  DNN operators SHOULD restrict their input\n    //        and output types to tensors.\n\n    // The type of a sequence.\n    Sequence sequence_type = 4;\n\n    // The type of a map.\n    Map map_type = 5;\n\n    // The type of an optional.\n    Optional optional_type = 9;\n\n\n    // Type of the sparse tensor\n    SparseTensor sparse_tensor_type = 8;\n\n  }\n\n  // An optional denotation can be used to denote the whole\n  // type with a standard semantic description as to what is\n  // stored inside. 
Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition\n  // for pre-defined type denotations.\n  string denotation = 6;\n}\n\n// Operator Sets\n//\n// OperatorSets are uniquely identified by a (domain, opset_version) pair.\nmessage OperatorSetIdProto {\n  // The domain of the operator set being identified.\n  // The empty string (\"\") or absence of this field implies the operator\n  // set that is defined as part of the ONNX specification.\n  // This field MUST be present in this version of the IR when referring to any other operator set.\n  string domain = 1;\n\n  // The version of the operator set being identified.\n  // This field MUST be present in this version of the IR.\n  int64 version = 2;\n}\n\n// Operator/function status.\nenum OperatorStatus {\n    EXPERIMENTAL = 0;\n    STABLE = 1;\n}\n\nmessage FunctionProto {\n  // The name of the function, similar to op_type in NodeProto.\n  // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.\n  string name = 1;\n\n  // Deprecated since IR Version 8\n  // optional int64 since_version = 2;\n  reserved 2;\n  reserved \"since_version\";\n\n  // Deprecated since IR Version 8\n  // optional OperatorStatus status = 3;\n  reserved 3;\n  reserved \"status\";\n\n  // The inputs and outputs of the function.\n  repeated string input = 4;\n  repeated string output = 5;\n\n  // The attribute parameters of the function.\n  // It is for function parameters without default values.\n  repeated string attribute = 6;\n\n  // The attribute protos of the function.\n  // It is for function attributes with default values.\n  // A function attribute shall be represented either as\n  // a string attribute or an AttributeProto, not both.\n  repeated AttributeProto attribute_proto = 11;\n\n  // The nodes in the function.\n  repeated NodeProto node = 7;\n  // A human-readable documentation for this function. 
Markdown is allowed.\n  string doc_string = 8;\n\n  // The OperatorSets this function body (graph) relies on.\n  //\n  // All nodes in the function body (graph) will bind against the operator\n  // with the same-domain/same-op_type operator with the HIGHEST version\n  // in the referenced operator sets. This means at most one version can be relied\n  // on for one domain.\n  //\n  // The operator sets imported by FunctionProto should be compatible with the ones\n  // imported by ModelProto. For example, if the same operator set, say 'A', is imported by both\n  // FunctionProto and ModelProto, then the versions for the operator set may differ, but\n  // the operator schema returned for the (op_type, domain, version) combination\n  // for both versions should be the same.\n\n  repeated OperatorSetIdProto opset_import = 9;\n\n  // The domain which this function belongs to.\n  // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.\n  string domain = 10;\n\n  // The overload identifier of the function.\n  // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model.\n  string overload = 13;\n\n  // Information for the values in the function. The ValueInfoProto.name's\n  // must be distinct and refer to names in the function (including inputs,\n  // outputs, and intermediate values). It is optional for a value to appear\n  // in value_info list.\n  repeated ValueInfoProto value_info = 12;\n\n  // Named metadata values; keys should be distinct.\n  repeated StringStringEntryProto metadata_props = 14;\n}\n\n// For using protobuf-lite\noption optimize_for = LITE_RUNTIME;\n\n"
  },
  {
    "path": "crates/dsperse/src/backend/jstprove.rs",
    "content": "use std::collections::HashMap;\nuse std::path::{Path, PathBuf};\nuse std::sync::{Arc, Mutex};\n\npub use jstprove_circuits::api::ExtractedOutputType as ExtractedOutput;\npub use jstprove_circuits::api::ProofConfigType as ProofConfig;\npub use jstprove_circuits::api::StampedProofConfigType as StampedProofConfig;\npub use jstprove_circuits::api::VerifiedOutputType as VerifiedOutput;\nuse jstprove_circuits::api::{\n    self, ArchitectureType as Architecture, CircuitParamsType as CircuitParams,\n    CompiledCircuitType as CompiledCircuit, WANDBType as WANDB,\n};\nuse jstprove_circuits::runner::schema::WitnessRequest;\n\nuse crate::error::{DsperseError, Result};\n\nuse super::traits::ProofBackend;\n\n#[derive(Debug)]\npub struct JstproveBackend {\n    compress: bool,\n    bundle_cache: Mutex<HashMap<PathBuf, Arc<CompiledCircuit>>>,\n}\n\nimpl Default for JstproveBackend {\n    fn default() -> Self {\n        Self {\n            compress: true,\n            bundle_cache: Mutex::new(HashMap::new()),\n        }\n    }\n}\n\nimpl JstproveBackend {\n    pub fn new() -> Self {\n        Self::default()\n    }\n\n    pub fn with_compress(mut self, compress: bool) -> Self {\n        self.compress = compress;\n        self\n    }\n\n    pub fn compress(&self) -> bool {\n        self.compress\n    }\n\n    pub fn load_bundle_cached(&self, path: &Path) -> Result<Arc<CompiledCircuit>> {\n        let key = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());\n\n        let mut cache = self\n            .bundle_cache\n            .lock()\n            .map_err(|e| DsperseError::Backend(format!(\"bundle cache lock poisoned: {e}\")))?;\n        if let Some(bundle) = cache.get(&key) {\n            return Ok(Arc::clone(bundle));\n        }\n        let bundle = Arc::new(load_bundle(path)?);\n        cache.insert(key, Arc::clone(&bundle));\n\n        Ok(bundle)\n    }\n\n    pub fn clear_cache(&self) {\n        let mut cache = match self.bundle_cache.lock() {\n        
    Ok(cache) => cache,\n            Err(e) => {\n                tracing::warn!(\"bundle cache lock poisoned on clear: {e}\");\n                e.into_inner()\n            }\n        };\n        let count = cache.len();\n        cache.clear();\n        tracing::debug!(cleared = count, \"bundle cache cleared\");\n    }\n\n    /// Evict cached bundles whose canonical path starts with the given\n    /// prefix. Used by callers that want to drop a model's entries\n    /// without clearing the entire cache.\n    pub fn evict_cache_by_prefix(&self, prefix: &Path) {\n        let mut cache = match self.bundle_cache.lock() {\n            Ok(cache) => cache,\n            Err(e) => {\n                tracing::warn!(\"bundle cache lock poisoned on evict: {e}\");\n                e.into_inner()\n            }\n        };\n        let before = cache.len();\n        cache.retain(|k, _| !k.starts_with(prefix));\n        let evicted = before - cache.len();\n        if evicted > 0 {\n            tracing::info!(\n                prefix = %prefix.display(),\n                evicted,\n                remaining = cache.len(),\n                \"evicted bundle cache entries\"\n            );\n        }\n    }\n\n    /// Resolve the proof config for a freshly loaded bundle. 
Errors if\n    /// the bundle does not carry a stamped proof config or if the\n    /// stamped version does not match the current spec, so callers can\n    /// fail fast on legacy or incompatible bundles instead of running\n    /// the wrong prover.\n    fn resolve_proof_config(bundle: &CompiledCircuit) -> Result<ProofConfig> {\n        let stamped = bundle\n            .metadata\n            .as_ref()\n            .and_then(|m| m.proof_config)\n            .ok_or_else(|| {\n                DsperseError::Backend(\n                    \"circuit bundle has no stamped proof_config; recompile with a stamping prover\"\n                        .into(),\n                )\n            })?;\n        stamped\n            .ensure_current()\n            .map_err(|e| DsperseError::Backend(format!(\"incompatible bundle: {e}\")))?;\n        Ok(stamped.config)\n    }\n\n    /// Resolve the proof config without touching the circuit or\n    /// witness-solver blobs. Reads only `manifest.msgpack`, which is\n    /// kilobytes versus the tens of megabytes a full bundle load\n    /// pulls in. 
Falls back to `resolve_proof_config` on a full\n    /// bundle load if the manifest is missing the stamp so callers\n    /// still get the same \"no stamped proof_config\" error path for\n    /// legacy bundles rather than a confusing deserialization\n    /// failure.\n    fn resolve_proof_config_from_manifest(&self, circuit_path: &Path) -> Result<ProofConfig> {\n        match jstprove_io::bundle::read_bundle_metadata::<CircuitParams>(circuit_path) {\n            Ok((Some(params), _)) => {\n                let stamped = params.proof_config.ok_or_else(|| {\n                    DsperseError::Backend(\n                        \"circuit bundle has no stamped proof_config; recompile with a stamping prover\"\n                            .into(),\n                    )\n                })?;\n                stamped\n                    .ensure_current()\n                    .map_err(|e| DsperseError::Backend(format!(\"incompatible bundle: {e}\")))?;\n                Ok(stamped.config)\n            }\n            Ok((None, _)) => {\n                let bundle = self.load_bundle_cached(circuit_path)?;\n                Self::resolve_proof_config(&bundle)\n            }\n            Err(e) => {\n                // Surface the manifest-read failure so operators\n                // investigating a slow verify path or a legacy\n                // bundle layout can tell the fast path missed\n                // rather than silently eating a parse / IO error.\n                tracing::debug!(\n                    path = %circuit_path.display(),\n                    error = %e,\n                    \"manifest-only proof_config read failed; falling back to full bundle load\"\n                );\n                let bundle = self.load_bundle_cached(circuit_path)?;\n                Self::resolve_proof_config(&bundle)\n            }\n        }\n    }\n\n    pub fn compile(\n        &self,\n        circuit_path: &Path,\n        config: ProofConfig,\n        params: CircuitParams,\n       
 architecture: Architecture,\n        wandb: WANDB,\n    ) -> Result<()> {\n        let circuit_path_str = circuit_path\n            .to_str()\n            .ok_or_else(|| DsperseError::Backend(\"non-UTF8 circuit path\".into()))?;\n\n        api::compile(\n            circuit_path_str,\n            config,\n            params,\n            architecture,\n            wandb,\n            self.compress,\n        )\n        .map_err(|e| DsperseError::Backend(format!(\"compile: {e}\")))?;\n\n        let key = circuit_path\n            .canonicalize()\n            .unwrap_or_else(|_| circuit_path.to_path_buf());\n        self.bundle_cache\n            .lock()\n            .map_err(|e| DsperseError::Backend(format!(\"bundle cache lock poisoned: {e}\")))?\n            .remove(&key);\n\n        Ok(())\n    }\n\n    pub fn witness(\n        &self,\n        circuit_path: &Path,\n        input_json: &[u8],\n        output_json: &[u8],\n    ) -> Result<Vec<u8>> {\n        let bundle = self.load_bundle_cached(circuit_path)?;\n        let config = Self::resolve_proof_config(&bundle)?;\n\n        let req = WitnessRequest {\n            circuit: bundle.circuit.clone(),\n            witness_solver: bundle.witness_solver.clone(),\n            inputs: input_json.to_vec(),\n            outputs: output_json.to_vec(),\n            metadata: bundle.metadata.clone(),\n        };\n\n        let result = api::witness(config, &req, self.compress)\n            .map_err(|e| DsperseError::Backend(format!(\"witness: {e}\")))?;\n\n        Ok(result.witness)\n    }\n\n    pub fn witness_f64(\n        &self,\n        circuit_path: &Path,\n        activations: &[f64],\n        initializers: &[(Vec<f64>, Vec<usize>)],\n    ) -> Result<Vec<u8>> {\n        let bundle = self.load_bundle_cached(circuit_path)?;\n        let config = Self::resolve_proof_config(&bundle)?;\n        let params = bundle.metadata.as_ref().ok_or_else(|| {\n            DsperseError::Backend(\n                \"circuit bundle 
missing metadata (required for quantization)\".into(),\n            )\n        })?;\n\n        let result = api::witness_f64(\n            config,\n            &bundle.circuit,\n            &bundle.witness_solver,\n            params,\n            activations,\n            initializers,\n            self.compress,\n        )\n        .map_err(|e| DsperseError::Backend(format!(\"witness_f64: {e}\")))?;\n\n        Ok(result.witness)\n    }\n\n    pub fn load_params(&self, circuit_path: &Path) -> Result<Option<CircuitParams>> {\n        let bundle = self.load_bundle_cached(circuit_path)?;\n        Ok(bundle.metadata.clone())\n    }\n\n    pub fn prove(&self, circuit_path: &Path, witness_bytes: &[u8]) -> Result<Vec<u8>> {\n        let bundle = self.load_bundle_cached(circuit_path)?;\n        let config = Self::resolve_proof_config(&bundle)?;\n\n        api::prove(config, &bundle.circuit, witness_bytes, self.compress)\n            .map_err(|e| DsperseError::Backend(format!(\"prove: {e}\")))\n    }\n\n    pub fn extract_outputs(\n        &self,\n        witness_bytes: &[u8],\n        num_model_inputs: usize,\n    ) -> Result<Vec<f64>> {\n        Ok(self\n            .extract_outputs_full(witness_bytes, num_model_inputs)?\n            .outputs)\n    }\n\n    /// Full extracted output bundle: inputs, outputs, and the\n    /// witness-stamped scale parameters. Holographic verifiers call\n    /// this after `verify_holographic` because the holographic\n    /// verify path does not reach through `verify_and_extract`, yet\n    /// the validator still needs the declared inputs (to cross-check\n    /// against what it sent) and the scale fields (to report the\n    /// same `VerifiedOutput` shape the non-holographic path\n    /// produces). 
Keeping `extract_outputs` as a thin wrapper\n    /// preserves the existing `Vec<f64>` contract for callers that\n    /// only want the outputs.\n    pub fn extract_outputs_full(\n        &self,\n        witness_bytes: &[u8],\n        num_model_inputs: usize,\n    ) -> Result<ExtractedOutput> {\n        if num_model_inputs == 0 {\n            return Err(DsperseError::Backend(\n                \"extract_outputs: num_model_inputs must be > 0\".into(),\n            ));\n        }\n        api::extract_outputs(witness_bytes, num_model_inputs)\n            .map_err(|e| DsperseError::Backend(format!(\"extract_outputs: {e}\")))\n    }\n\n    pub fn verify(\n        &self,\n        circuit_path: &Path,\n        witness_bytes: &[u8],\n        proof_bytes: &[u8],\n    ) -> Result<bool> {\n        let bundle = self.load_bundle_cached(circuit_path)?;\n        let config = Self::resolve_proof_config(&bundle)?;\n\n        api::verify(config, &bundle.circuit, witness_bytes, proof_bytes)\n            .map_err(|e| DsperseError::Backend(format!(\"verify: {e}\")))\n    }\n\n    pub fn verify_and_extract(\n        &self,\n        circuit_path: &Path,\n        witness_bytes: &[u8],\n        proof_bytes: &[u8],\n        num_inputs: usize,\n        expected_inputs: Option<&[f64]>,\n    ) -> Result<VerifiedOutput> {\n        let bundle = self.load_bundle_cached(circuit_path)?;\n        let config = Self::resolve_proof_config(&bundle)?;\n\n        api::verify_and_extract(\n            config,\n            &bundle.circuit,\n            witness_bytes,\n            proof_bytes,\n            num_inputs,\n            expected_inputs,\n        )\n        .map_err(|e| DsperseError::Backend(format!(\"verify_and_extract: {e}\")))\n    }\n\n    /// Run holographic GKR setup against the compiled circuit at\n    /// `circuit_path` and persist the resulting verifying key as\n    /// `vk.bin` inside the bundle directory. 
The bundle is read from\n    /// the cache, so callers that just compiled the bundle through\n    /// [`Self::compile`] pay only the holographic setup cost on top.\n    ///\n    /// `setup_holographic_vk` only succeeds when the bundle was\n    /// compiled with `ProofConfig::GoldilocksExt4Whir`; the underlying\n    /// jstprove API rejects every other config.\n    ///\n    /// The vk blob is written using the same compression mode as the\n    /// rest of the bundle (`Self::compress`) so\n    /// `jstprove_io::bundle::read_vk_only` can decode it via the\n    /// shared auto-detecting reader.\n    pub fn setup_holographic_vk(&self, circuit_path: &Path) -> Result<()> {\n        let bundle = self.load_bundle_cached(circuit_path)?;\n        let config = Self::resolve_proof_config(&bundle)?;\n\n        let vk_bytes = api::setup_holographic_vk(config, &bundle.circuit)\n            .map_err(|e| DsperseError::Backend(format!(\"setup_holographic_vk: {e}\")))?;\n\n        let vk_path = circuit_path.join(\"vk.bin\");\n        let payload = if self.compress {\n            jstprove_io::compress_bytes(&vk_bytes)\n                .map_err(|e| DsperseError::Backend(format!(\"compress vk: {e}\")))?\n        } else {\n            vk_bytes\n        };\n        std::fs::write(&vk_path, &payload).map_err(|e| DsperseError::io(e, &vk_path))?;\n        Ok(())\n    }\n\n    /// Generate a holographic GKR proof for an existing bundle and\n    /// witness. 
Like [`Self::setup_holographic_vk`] this requires the\n    /// bundle to have been compiled with\n    /// `ProofConfig::GoldilocksExt4Whir`.\n    pub fn prove_holographic(&self, circuit_path: &Path, witness_bytes: &[u8]) -> Result<Vec<u8>> {\n        let bundle = self.load_bundle_cached(circuit_path)?;\n        let config = Self::resolve_proof_config(&bundle)?;\n\n        api::prove_holographic(config, &bundle.circuit, witness_bytes)\n            .map_err(|e| DsperseError::Backend(format!(\"prove_holographic: {e}\")))\n    }\n\n    /// Verify a holographic GKR proof against the bundle's vk.bin.\n    /// The vk is read independently of the (much larger) circuit\n    /// blob, mirroring the validator-side flow where the verifying\n    /// party only ever ships the vk.\n    pub fn verify_holographic(&self, circuit_path: &Path, proof_bytes: &[u8]) -> Result<bool> {\n        // Verifiers only need the vk and the proof config — the\n        // circuit and witness solver blobs are not used downstream.\n        // Skip load_bundle_cached here so validators that only ever\n        // hold vk.bin + manifest.msgpack (the intended light-weight\n        // deployment shape) don't fail with a missing circuit.bin\n        // and don't pay the tens-of-megabytes read cost.\n        let config = self.resolve_proof_config_from_manifest(circuit_path)?;\n        let vk_bytes = jstprove_io::bundle::read_vk_only(circuit_path)\n            .map_err(|e| DsperseError::Backend(format!(\"read vk: {e}\")))?;\n\n        api::verify_holographic(config, &vk_bytes, proof_bytes)\n            .map_err(|e| DsperseError::Backend(format!(\"verify_holographic: {e}\")))\n    }\n}\n\nimpl ProofBackend for JstproveBackend {\n    fn prove(&self, circuit_path: &Path, witness_bytes: &[u8]) -> Result<Vec<u8>> {\n        self.prove(circuit_path, witness_bytes)\n    }\n\n    fn verify(\n        &self,\n        circuit_path: &Path,\n        witness_bytes: &[u8],\n        proof_bytes: &[u8],\n    ) -> Result<bool> 
{\n        self.verify(circuit_path, witness_bytes, proof_bytes)\n    }\n\n    fn witness_f64(\n        &self,\n        circuit_path: &Path,\n        activations: &[f64],\n        initializers: &[(Vec<f64>, Vec<usize>)],\n    ) -> Result<Vec<u8>> {\n        self.witness_f64(circuit_path, activations, initializers)\n    }\n}\n\nfn load_bundle(circuit_path: &Path) -> Result<CompiledCircuit> {\n    let path_str = circuit_path\n        .to_str()\n        .ok_or_else(|| DsperseError::Backend(\"non-UTF8 circuit path\".into()))?;\n\n    api::read_circuit_bundle(path_str)\n        .map_err(|e| DsperseError::Backend(format!(\"read circuit bundle: {e}\")))\n}\n\npub struct WarmCircuit {\n    bundle: Arc<CompiledCircuit>,\n    pub params: CircuitParams,\n    initializers: Vec<(Vec<f64>, Vec<usize>)>,\n    compress: bool,\n    config: ProofConfig,\n}\n\nimpl WarmCircuit {\n    pub fn load(\n        circuit_path: &Path,\n        initializers: Vec<(Vec<f64>, Vec<usize>)>,\n        backend: &JstproveBackend,\n    ) -> Result<Self> {\n        let bundle = backend.load_bundle_cached(circuit_path)?;\n        let config = JstproveBackend::resolve_proof_config(&bundle)?;\n        let params = bundle\n            .metadata\n            .clone()\n            .ok_or_else(|| DsperseError::Backend(\"circuit bundle missing metadata\".into()))?;\n        Ok(Self {\n            bundle,\n            params,\n            initializers,\n            compress: backend.compress(),\n            config,\n        })\n    }\n\n    pub fn witness_f64(&self, activations: &[f64]) -> Result<Vec<u8>> {\n        let result = api::witness_f64(\n            self.config,\n            &self.bundle.circuit,\n            &self.bundle.witness_solver,\n            &self.params,\n            activations,\n            &self.initializers,\n            self.compress,\n        )\n        .map_err(|e| DsperseError::Backend(format!(\"witness_f64: {e}\")))?;\n\n        Ok(result.witness)\n    }\n}\n\n#[cfg(test)]\nmod tests 
{\n    use super::*;\n\n    #[test]\n    fn bundle_cache_starts_empty() {\n        let backend = JstproveBackend::default();\n        let cache = backend.bundle_cache.lock().unwrap();\n        assert!(cache.is_empty());\n    }\n\n    #[test]\n    fn backend_constructs_without_proof_config_state() {\n        let backend = JstproveBackend::default();\n        assert!(backend.compress());\n    }\n\n    #[test]\n    fn clear_cache_on_empty_succeeds() {\n        let backend = JstproveBackend::default();\n        backend.clear_cache();\n        let cache = backend.bundle_cache.lock().unwrap();\n        assert!(cache.is_empty());\n    }\n\n    #[test]\n    fn clear_cache_removes_entries() {\n        let backend = JstproveBackend::default();\n        let dummy = Arc::new(CompiledCircuit {\n            circuit: vec![1, 2, 3],\n            witness_solver: vec![],\n            metadata: None,\n            version: None,\n        });\n        backend\n            .bundle_cache\n            .lock()\n            .unwrap()\n            .insert(PathBuf::from(\"/tmp/test-circuit\"), dummy);\n        assert_eq!(backend.bundle_cache.lock().unwrap().len(), 1);\n        backend.clear_cache();\n        assert!(backend.bundle_cache.lock().unwrap().is_empty());\n    }\n\n    #[test]\n    fn load_bundle_cached_returns_error_for_missing_path() {\n        let backend = JstproveBackend::default();\n        let result = backend.load_bundle_cached(Path::new(\"/nonexistent/circuit/path\"));\n        assert!(result.is_err());\n        assert!(backend.bundle_cache.lock().unwrap().is_empty());\n    }\n\n    #[test]\n    fn resolve_proof_config_rejects_unstamped_bundle() {\n        let bundle = CompiledCircuit {\n            circuit: vec![],\n            witness_solver: vec![],\n            metadata: None,\n            version: None,\n        };\n        let err = JstproveBackend::resolve_proof_config(&bundle).unwrap_err();\n        match err {\n            DsperseError::Backend(msg) => {\n          
      assert!(msg.contains(\"no stamped proof_config\"), \"{msg}\")\n            }\n            other => panic!(\"expected Backend error, got {other:?}\"),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/backend/mod.rs",
    "content": "pub mod jstprove;\npub mod onnx;\npub mod traits;\n\npub use traits::ProofBackend;\n"
  },
  {
    "path": "crates/dsperse/src/backend/onnx.rs",
    "content": "use std::collections::HashMap;\nuse std::path::Path;\nuse std::sync::Arc;\n\nuse ndarray::IxDyn;\nuse tract_onnx::prelude::*;\nuse tract_onnx::tract_hir::infer::Factoid;\n\nuse crate::error::{DsperseError, Result};\n\npub fn coerce_tdim_inputs(inputs: &TVec<TValue>) -> TVec<TValue> {\n    inputs\n        .iter()\n        .map(|t| {\n            if t.datum_type() == DatumType::TDim {\n                // Safety: datum_type() == TDim verified by outer condition\n                let view = unsafe { t.as_slice_unchecked::<TDim>() };\n                let vals: Vec<i64> = view.iter().map(|d| d.to_i64().unwrap_or(0)).collect();\n                Tensor::from_shape(t.shape(), &vals)\n                    .map(|t| t.into_tvalue())\n                    .unwrap_or_else(|_| t.clone())\n            } else {\n                t.clone()\n            }\n        })\n        .collect()\n}\n\npub type NamedOutputs = HashMap<String, (Vec<f64>, Vec<usize>)>;\n\nfn load_onnx_model(onnx_path: &Path) -> Result<InferenceModel> {\n    tract_onnx::onnx()\n        .model_for_path(onnx_path)\n        .map_err(|e| DsperseError::Onnx(format!(\"load {}: {e}\", onnx_path.display())))\n}\n\nfn resolve_concrete_shape(model: &InferenceModel, input_shape: &[usize]) -> Result<Vec<usize>> {\n    let model_shape = model\n        .input_fact(0)\n        .ok()\n        .and_then(|f| f.shape.as_concrete_finite().ok().flatten())\n        .map(|s| s.to_vec());\n\n    if input_shape.is_empty() {\n        return model_shape.ok_or_else(|| {\n            DsperseError::Onnx(\"symbolic input shape — provide explicit shape\".into())\n        });\n    }\n\n    if let Some(ref ms) = model_shape {\n        let model_elems: usize = ms.iter().product();\n        let input_elems: usize = input_shape.iter().product();\n        if input_shape.len() == 1 && ms.len() > 1 && model_elems == input_elems {\n            tracing::debug!(\n                model_shape = ?ms,\n                provided_shape = 
?input_shape,\n                \"reshaping flat input to model-declared shape\"\n            );\n            return Ok(ms.clone());\n        }\n    }\n\n    Ok(input_shape.to_vec())\n}\n\nfn resolve_input_datum_type(model: &InferenceModel, idx: usize) -> Result<DatumType> {\n    let fact = model\n        .input_fact(idx)\n        .map_err(|e| DsperseError::Onnx(format!(\"input fact at index {idx}: {e}\")))?;\n    fact.datum_type.concretize().ok_or_else(|| {\n        DsperseError::Onnx(format!(\n            \"input fact at index {idx} has no concrete datum type; the model must declare a concrete element type for this input\"\n        ))\n    })\n}\n\nfn optimize_to_runnable(\n    model: InferenceModel,\n    concrete_shape: &[usize],\n    input_dt: DatumType,\n) -> Result<Arc<TypedRunnableModel>> {\n    model\n        .with_input_fact(0, InferenceFact::dt_shape(input_dt, concrete_shape))\n        .map_err(|e| DsperseError::Onnx(format!(\"set input shape: {e}\")))?\n        .into_optimized()\n        .map_err(|e| DsperseError::Onnx(format!(\"optimize: {e:#}\")))?\n        .into_runnable()\n        .map_err(|e| DsperseError::Onnx(format!(\"make runnable: {e:#}\")))\n}\n\npub fn run_inference_with_coercion(\n    onnx_path: &Path,\n    input_data: &[f64],\n    input_shape: &[usize],\n) -> Result<NamedOutputs> {\n    let model = load_onnx_model(onnx_path)?;\n    let concrete_shape = resolve_concrete_shape(&model, input_shape)?;\n    let input_dt = resolve_input_datum_type(&model, 0)?;\n\n    if let Ok(plan) = optimize_to_runnable(model, &concrete_shape, input_dt) {\n        let input = build_input_tvalue(input_data, &concrete_shape, input_dt)?;\n        let result = plan\n            .run(tvec![input])\n            .map_err(|e| DsperseError::Onnx(format!(\"run: {e:#}\")))?;\n        return extract_all_outputs(&result);\n    }\n\n    tracing::warn!(\"standard optimization failed; using inference plan with TDim coercion\");\n    let model2 = load_onnx_model(onnx_path)?;\n   
 let with_shape = model2\n        .with_input_fact(0, InferenceFact::dt_shape(input_dt, &concrete_shape))\n        .map_err(|e| DsperseError::Onnx(format!(\"set input: {e}\")))?;\n\n    let plan =\n        tract_onnx::tract_hir::infer::InferenceSimplePlan::new(std::sync::Arc::new(with_shape))\n            .map_err(|e| DsperseError::Onnx(format!(\"inference plan: {e}\")))?;\n    let mut state = tract_onnx::tract_core::plan::SimpleState::new(&plan)\n        .map_err(|e| DsperseError::Onnx(format!(\"state: {e}\")))?;\n\n    let input = build_input_tvalue(input_data, &concrete_shape, input_dt)?;\n    let result = state\n        .run_plan_with_eval(tvec![input], |session, op_state, node, inputs| {\n            let coerced = coerce_tdim_inputs(&inputs);\n            let eval_result = if let Some(st) = op_state {\n                st.eval(session, node.op.as_op(), coerced)\n            } else {\n                node.op.eval(coerced)\n            };\n            match eval_result {\n                Ok(o) => Ok::<_, TractError>(o),\n                Err(e) => {\n                    let Some(first) = inputs.first() else {\n                        return Err(e);\n                    };\n                    tracing::warn!(node = %node.name, error = %e, \"eval failed, using fallback\");\n                    let dt = first.datum_type();\n                    let fallback = Tensor::zero_dt(dt, &[1])\n                        .map_err(|alloc_err| {\n                            TractError::msg(format!(\n                                \"node {}: eval failed ({e}); fallback allocation for dtype {dt:?} failed: {alloc_err}\",\n                                node.name\n                            ))\n                        })?\n                        .into_tvalue();\n                    let n = node.outputs.len().max(1);\n                    Ok((0..n).map(|_| fallback.clone()).collect())\n                }\n            }\n        })\n        .map_err(|e| 
DsperseError::Onnx(format!(\"inference run: {e:#}\")))?;\n\n    extract_all_outputs(&result)\n}\n\nfn extract_all_outputs(result: &[TValue]) -> Result<NamedOutputs> {\n    let mut outputs = NamedOutputs::new();\n    for (i, tv) in result.iter().enumerate() {\n        let label = format!(\"output_{i}\");\n        let (data, shape) = tvalue_to_f64(tv, &label)?;\n        outputs.insert(label, (data, shape));\n    }\n    Ok(outputs)\n}\n\nfn load_runnable(\n    onnx_path: &Path,\n    input_shape: &[usize],\n) -> Result<(Arc<TypedRunnableModel>, Vec<usize>, DatumType)> {\n    let model = load_onnx_model(onnx_path)?;\n    let concrete_shape = resolve_concrete_shape(&model, input_shape)?;\n    let input_dt = resolve_input_datum_type(&model, 0)?;\n    let plan = optimize_to_runnable(model, &concrete_shape, input_dt)?;\n    Ok((plan, concrete_shape, input_dt))\n}\n\nconst I64_SAFE_BOUND_F64: f64 = I64_SAFE_BOUND as f64;\n\nfn reject_non_finite(v: f64, idx: usize, type_name: &str) -> Result<()> {\n    if !v.is_finite() {\n        return Err(DsperseError::Onnx(format!(\n            \"input[{idx}] = {v}: non-finite values are not accepted for {type_name} inputs\"\n        )));\n    }\n    Ok(())\n}\n\nfn validate_integer_input(\n    v: f64,\n    idx: usize,\n    type_name: &str,\n    type_min: f64,\n    type_max: f64,\n) -> Result<()> {\n    reject_non_finite(v, idx, type_name)?;\n    if v.trunc() != v {\n        return Err(DsperseError::Onnx(format!(\n            \"input[{idx}] = {v}: fractional component cannot be represented as {type_name}\"\n        )));\n    }\n    if v.abs() > I64_SAFE_BOUND_F64 {\n        return Err(DsperseError::Onnx(format!(\n            \"input[{idx}] = {v}: magnitude exceeds IEEE-754 safe integer bound {I64_SAFE_BOUND}\"\n        )));\n    }\n    if v < type_min || v > type_max {\n        return Err(DsperseError::Onnx(format!(\n            \"input[{idx}] = {v}: outside representable range [{type_min}, {type_max}] for {type_name}\"\n        )));\n    
}\n    Ok(())\n}\n\nfn build_input_tvalue(input_data: &[f64], shape: &[usize], dt: DatumType) -> Result<TValue> {\n    let f32_max_f64: f64 = f32::MAX as f64;\n    macro_rules! build_bounded_int {\n        ($t:ty, $name:expr, $min:expr, $max:expr) => {{\n            let mut data: Vec<$t> = Vec::with_capacity(input_data.len());\n            for (i, &v) in input_data.iter().enumerate() {\n                validate_integer_input(v, i, $name, $min as f64, $max as f64)?;\n                data.push(v as $t);\n            }\n            tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data)\n                .map(|a| a.into_tvalue())\n                .map_err(|e| DsperseError::Onnx(format!(\"input tensor: {e}\")))\n        }};\n    }\n    if dt == f32::datum_type() {\n        let mut data: Vec<f32> = Vec::with_capacity(input_data.len());\n        for (i, &v) in input_data.iter().enumerate() {\n            reject_non_finite(v, i, \"f32\")?;\n            if v < -f32_max_f64 || v > f32_max_f64 {\n                return Err(DsperseError::Onnx(format!(\n                    \"input[{i}] = {v}: magnitude exceeds representable f32 range [-{f32_max_f64}, {f32_max_f64}]\"\n                )));\n            }\n            data.push(v as f32);\n        }\n        tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data)\n            .map(|a| a.into_tvalue())\n            .map_err(|e| DsperseError::Onnx(format!(\"input tensor: {e}\")))\n    } else if dt == f64::datum_type() {\n        let mut data: Vec<f64> = Vec::with_capacity(input_data.len());\n        for (i, &v) in input_data.iter().enumerate() {\n            reject_non_finite(v, i, \"f64\")?;\n            data.push(v);\n        }\n        tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data)\n            .map(|a| a.into_tvalue())\n            .map_err(|e| DsperseError::Onnx(format!(\"input tensor: {e}\")))\n    } else if dt == u8::datum_type() {\n        build_bounded_int!(u8, \"u8\", u8::MIN, u8::MAX)\n    } else if dt == 
i8::datum_type() {\n        build_bounded_int!(i8, \"i8\", i8::MIN, i8::MAX)\n    } else if dt == u16::datum_type() {\n        build_bounded_int!(u16, \"u16\", u16::MIN, u16::MAX)\n    } else if dt == i16::datum_type() {\n        build_bounded_int!(i16, \"i16\", i16::MIN, i16::MAX)\n    } else if dt == u32::datum_type() {\n        build_bounded_int!(u32, \"u32\", u32::MIN, u32::MAX)\n    } else if dt == i32::datum_type() {\n        build_bounded_int!(i32, \"i32\", i32::MIN, i32::MAX)\n    } else if dt == u64::datum_type() {\n        let mut data: Vec<u64> = Vec::with_capacity(input_data.len());\n        for (i, &v) in input_data.iter().enumerate() {\n            validate_integer_input(v, i, \"u64\", 0.0, I64_SAFE_BOUND_F64)?;\n            data.push(v as u64);\n        }\n        tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data)\n            .map(|a| a.into_tvalue())\n            .map_err(|e| DsperseError::Onnx(format!(\"input tensor: {e}\")))\n    } else if dt == i64::datum_type() {\n        let mut data: Vec<i64> = Vec::with_capacity(input_data.len());\n        for (i, &v) in input_data.iter().enumerate() {\n            validate_integer_input(v, i, \"i64\", -I64_SAFE_BOUND_F64, I64_SAFE_BOUND_F64)?;\n            data.push(v as i64);\n        }\n        tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), data)\n            .map(|a| a.into_tvalue())\n            .map_err(|e| DsperseError::Onnx(format!(\"input tensor: {e}\")))\n    } else if dt == bool::datum_type() {\n        let mut data: Vec<bool> = Vec::with_capacity(input_data.len());\n        for (i, &v) in input_data.iter().enumerate() {\n            reject_non_finite(v, i, \"bool\")?;\n            if v != 0.0 && v != 1.0 {\n                return Err(DsperseError::Onnx(format!(\n                    \"input[{i}] = {v}: bool inputs must be exactly 0 or 1\"\n                )));\n            }\n            data.push(v != 0.0);\n        }\n        tract_ndarray::ArrayD::from_shape_vec(IxDyn(shape), 
data)\n            .map(|a| a.into_tvalue())\n            .map_err(|e| DsperseError::Onnx(format!(\"input tensor: {e}\")))\n    } else {\n        Err(DsperseError::Onnx(format!(\n            \"unsupported input datum type {dt:?}\"\n        )))\n    }\n}\n\nfn run_single(\n    plan: &Arc<TypedRunnableModel>,\n    input_data: &[f64],\n    shape: &[usize],\n    dt: DatumType,\n) -> Result<TVec<TValue>> {\n    let tv = build_input_tvalue(input_data, shape, dt)?;\n    plan.run(tvec!(tv))\n        .map_err(|e| DsperseError::Onnx(format!(\"inference: {e}\")))\n}\n\npub struct WarmModel {\n    plan: Arc<TypedRunnableModel>,\n    input_shape: Vec<usize>,\n    input_dt: DatumType,\n}\n\nimpl WarmModel {\n    pub fn load(onnx_path: &Path, input_shape: &[usize]) -> Result<Self> {\n        let (plan, input_shape, input_dt) = load_runnable(onnx_path, input_shape)?;\n        Ok(Self {\n            plan,\n            input_shape,\n            input_dt,\n        })\n    }\n\n    pub fn run(&self, input_data: &[f64]) -> Result<(Vec<f64>, Vec<usize>)> {\n        let result = run_single(&self.plan, input_data, &self.input_shape, self.input_dt)?;\n        extract_first_output(&result)\n    }\n}\n\npub fn run_inference(\n    onnx_path: &Path,\n    input_data: &[f64],\n    input_shape: &[usize],\n) -> Result<(Vec<f64>, Vec<usize>)> {\n    let (plan, concrete_shape, input_dt) = load_runnable(onnx_path, input_shape)?;\n    let result = run_single(&plan, input_data, &concrete_shape, input_dt)?;\n    extract_first_output(&result)\n}\n\npub fn run_inference_named(\n    onnx_path: &Path,\n    input_data: &[f64],\n    input_shape: &[usize],\n) -> Result<NamedOutputs> {\n    let model = load_onnx_model(onnx_path)?;\n    let output_names = collect_output_names(&model);\n    let concrete_shape = resolve_concrete_shape(&model, input_shape)?;\n    let input_dt = resolve_input_datum_type(&model, 0)?;\n    match optimize_to_runnable(model, &concrete_shape, input_dt) {\n        Ok(plan) => {\n          
  let result = run_single(&plan, input_data, &concrete_shape, input_dt)?;\n            zip_named_outputs(&output_names, &result)\n        }\n        Err(_) => {\n            let mut result = run_inference_with_coercion(onnx_path, input_data, &concrete_shape)?;\n            let mut named = NamedOutputs::new();\n            for (i, name) in output_names.iter().enumerate() {\n                let key = format!(\"output_{i}\");\n                if let Some(val) = result.remove(&key) {\n                    named.insert(name.clone(), val);\n                }\n            }\n            Ok(named)\n        }\n    }\n}\n\npub fn run_inference_multi(\n    onnx_path: &Path,\n    inputs: &[(&str, Vec<f64>, Vec<usize>)],\n) -> Result<(Vec<f64>, Vec<usize>)> {\n    let (result, _) = run_multi_inner(onnx_path, inputs)?;\n    extract_first_output(&result)\n}\n\npub fn run_inference_multi_named(\n    onnx_path: &Path,\n    inputs: &[(&str, Vec<f64>, Vec<usize>)],\n) -> Result<NamedOutputs> {\n    let (result, output_names) = run_multi_inner(onnx_path, inputs)?;\n    zip_named_outputs(&output_names, &result)\n}\n\nfn run_multi_inner(\n    onnx_path: &Path,\n    inputs: &[(&str, Vec<f64>, Vec<usize>)],\n) -> Result<(TVec<TValue>, Vec<String>)> {\n    let mut model = load_onnx_model(onnx_path)?;\n\n    let output_names = collect_output_names(&model);\n\n    let mut input_by_name: HashMap<&str, usize> = HashMap::with_capacity(inputs.len());\n    for (idx, (name, _, _)) in inputs.iter().enumerate() {\n        if input_by_name.insert(*name, idx).is_some() {\n            return Err(DsperseError::Onnx(format!(\n                \"duplicate provided input name '{name}'\"\n            )));\n        }\n    }\n\n    let model_input_count = model.inputs.len();\n    let model_input_names: Vec<(usize, String)> = model\n        .inputs\n        .iter()\n        .enumerate()\n        .map(|(i, outlet)| (i, model.nodes[outlet.node].name.clone()))\n        .collect();\n\n    let mut input_order: 
Vec<Option<usize>> = vec![None; model_input_count];\n    let mut input_dts: Vec<Option<DatumType>> = vec![None; model_input_count];\n    for (i, name) in &model_input_names {\n        if let Some(&provided_idx) = input_by_name.get(name.as_str()) {\n            let dt = resolve_input_datum_type(&model, *i)?;\n            model = model\n                .with_input_fact(*i, InferenceFact::dt_shape(dt, &inputs[provided_idx].2))\n                .map_err(|e| DsperseError::Onnx(format!(\"set input {i} ({name}) shape: {e}\")))?;\n            input_order[*i] = Some(provided_idx);\n            input_dts[*i] = Some(dt);\n        }\n    }\n\n    let unknown_inputs: Vec<&str> = input_by_name\n        .keys()\n        .copied()\n        .filter(|name| !model_input_names.iter().any(|(_, n)| n == *name))\n        .collect();\n    if !unknown_inputs.is_empty() {\n        return Err(DsperseError::Onnx(format!(\n            \"provided inputs not present in model: {unknown_inputs:?}\"\n        )));\n    }\n\n    let model = model\n        .into_typed()\n        .map_err(|e| {\n            let unmatched: Vec<_> = input_order\n                .iter()\n                .enumerate()\n                .filter(|(_, v)| v.is_none())\n                .map(|(i, _)| model_input_names[i].1.as_str())\n                .collect();\n            DsperseError::Onnx(format!(\"type analysis (unmatched: {unmatched:?}): {e}\"))\n        })?\n        .into_optimized()\n        .map_err(|e| DsperseError::Onnx(format!(\"optimize: {e:#}\")))?\n        .into_runnable()\n        .map_err(|e| DsperseError::Onnx(format!(\"make runnable: {e:#}\")))?;\n\n    let mut input_tvs = TVec::new();\n    for (model_idx, idx) in input_order.iter().enumerate() {\n        let provided_idx = idx.ok_or_else(|| {\n            let name = &model_input_names[model_idx].1;\n            DsperseError::Onnx(format!(\n                \"model input {model_idx} ('{name}') not matched to provided tensors\"\n            ))\n        })?;\n     
   let dt = input_dts[model_idx].ok_or_else(|| {\n            let name = &model_input_names[model_idx].1;\n            DsperseError::Onnx(format!(\n                \"model input {model_idx} ('{name}') has no resolved datum type\"\n            ))\n        })?;\n        let (_, ref data, ref shape) = inputs[provided_idx];\n        input_tvs.push(build_input_tvalue(data, shape, dt)?);\n    }\n\n    let result = model\n        .run(input_tvs)\n        .map_err(|e| DsperseError::Onnx(format!(\"inference: {e}\")))?;\n\n    Ok((result, output_names))\n}\n\nfn collect_output_names(model: &InferenceModel) -> Vec<String> {\n    model\n        .outputs\n        .iter()\n        .map(|outlet| {\n            model\n                .outlet_label(*outlet)\n                .map(String::from)\n                .unwrap_or_else(|| {\n                    format!(\"{}_output_{}\", model.nodes[outlet.node].name, outlet.slot)\n                })\n        })\n        .collect()\n}\n\nconst I64_SAFE_BOUND: i64 = 9_007_199_254_740_992;\n\nfn i64_to_f64_checked(v: i64, label: &str) -> Result<f64> {\n    if v.abs() > I64_SAFE_BOUND {\n        return Err(DsperseError::Onnx(format!(\n            \"{label}: i64 value {v} exceeds IEEE-754 safe integer bound\"\n        )));\n    }\n    Ok(v as f64)\n}\n\nfn u64_to_f64_checked(v: u64, label: &str) -> Result<f64> {\n    if v > I64_SAFE_BOUND as u64 {\n        return Err(DsperseError::Onnx(format!(\n            \"{label}: u64 value {v} exceeds IEEE-754 safe integer bound\"\n        )));\n    }\n    Ok(v as f64)\n}\n\nfn tvalue_to_f64(tv: &TValue, label: &str) -> Result<(Vec<f64>, Vec<usize>)> {\n    let shape = tv.shape().to_vec();\n    let dt = tv.datum_type();\n    let data: Vec<f64> = if dt == f32::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<f32>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter().map(|&v| f64::from(v)).collect()\n    } else if dt == f64::datum_type() {\n     
   let arr = tv\n            .to_plain_array_view::<f64>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter().copied().collect()\n    } else if dt == i64::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<i64>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter()\n            .map(|&v| i64_to_f64_checked(v, label))\n            .collect::<Result<Vec<_>>>()?\n    } else if dt == i32::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<i32>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter().map(|&v| f64::from(v)).collect()\n    } else if dt == u32::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<u32>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter().map(|&v| f64::from(v)).collect()\n    } else if dt == i16::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<i16>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter().map(|&v| f64::from(v)).collect()\n    } else if dt == u16::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<u16>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter().map(|&v| f64::from(v)).collect()\n    } else if dt == i8::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<i8>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter().map(|&v| f64::from(v)).collect()\n    } else if dt == u8::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<u8>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter().map(|&v| f64::from(v)).collect()\n    } else if dt == u64::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<u64>()\n            
.map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter()\n            .map(|&v| u64_to_f64_checked(v, label))\n            .collect::<Result<Vec<_>>>()?\n    } else if dt == bool::datum_type() {\n        let arr = tv\n            .to_plain_array_view::<bool>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect()\n    } else if dt.is_tdim() {\n        let casted = tv\n            .cast_to::<i64>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: TDim->i64 cast: {e}\")))?;\n        let arr = casted\n            .to_plain_array_view::<i64>()\n            .map_err(|e| DsperseError::Onnx(format!(\"{label}: {e}\")))?;\n        arr.iter()\n            .map(|&v| i64_to_f64_checked(v, label))\n            .collect::<Result<Vec<_>>>()?\n    } else {\n        return Err(DsperseError::Onnx(format!(\n            \"{label}: unsupported datum type {dt:?}\"\n        )));\n    };\n    Ok((data, shape))\n}\n\nfn zip_named_outputs(names: &[String], result: &[TValue]) -> Result<NamedOutputs> {\n    let mut map = HashMap::new();\n    for (i, tv) in result.iter().enumerate() {\n        let (data, shape) = tvalue_to_f64(tv, &format!(\"output {i}\"))?;\n        let name = names\n            .get(i)\n            .cloned()\n            .unwrap_or_else(|| format!(\"output_{i}\"));\n        if map.insert(name.clone(), (data, shape)).is_some() {\n            return Err(DsperseError::Onnx(format!(\n                \"duplicate output name '{name}'\"\n            )));\n        }\n    }\n    Ok(map)\n}\n\nfn extract_first_output(result: &[TValue]) -> Result<(Vec<f64>, Vec<usize>)> {\n    let output = result\n        .first()\n        .ok_or_else(|| DsperseError::Onnx(\"no output from model\".into()))?;\n    tvalue_to_f64(output, \"output tensor\")\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    const TEST_OPS: &[&str] = &[\"Conv\", \"Gemm\", \"MatMul\"];\n\n    
#[test]\n    fn run_inference_on_sliced_model() {\n        let models_dir = std::path::PathBuf::from(concat!(\n            env!(\"CARGO_MANIFEST_DIR\"),\n            \"/../../tests/models/net\"\n        ));\n        let model_path = models_dir.join(\"model.onnx\");\n        assert!(\n            model_path.exists(),\n            \"fixture missing: {}\",\n            model_path.display()\n        );\n        let tmp = tempfile::tempdir().unwrap();\n        let meta = crate::slicer::slice_model(&model_path, Some(tmp.path()), None, TEST_OPS, None)\n            .expect(\"slice_model failed\");\n        crate::slicer::materializer::ensure_all_slices_materialized(tmp.path(), &meta)\n            .expect(\"materialization failed\");\n        assert!(!meta.slices.is_empty(), \"model produced zero slices\");\n        let first_slice = &meta.slices[0];\n        let onnx_path = tmp\n            .path()\n            .join(format!(\"slice_0/payload/{}\", first_slice.filename));\n        assert!(\n            onnx_path.exists(),\n            \"sliced ONNX missing: {}\",\n            onnx_path.display()\n        );\n        let input_shape = &first_slice.shape.tensor_shape.input;\n        assert!(\n            !input_shape.is_empty() && !input_shape[0].is_empty(),\n            \"empty input shape\"\n        );\n        let shape: Vec<usize> = input_shape[0].iter().map(|&d| d.max(1) as usize).collect();\n        let elem_count: usize = shape.iter().product();\n        let input_data = vec![0.0f64; elem_count];\n        let result = run_inference(&onnx_path, &input_data, &shape);\n        assert!(result.is_ok());\n        let (output_data, output_shape) = result.unwrap();\n        assert!(!output_data.is_empty());\n        assert!(!output_shape.is_empty());\n    }\n\n    #[test]\n    fn run_inference_nonexistent_model() {\n        let result = run_inference(Path::new(\"/nonexistent/model.onnx\"), &[1.0], &[1]);\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn 
warm_model_load_nonexistent() {\n        let result = WarmModel::load(Path::new(\"/nonexistent/model.onnx\"), &[1, 1, 28, 28]);\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn warm_model_load_and_run_on_slice() {\n        let models_dir = std::path::PathBuf::from(concat!(\n            env!(\"CARGO_MANIFEST_DIR\"),\n            \"/../../tests/models/net\"\n        ));\n        let model_path = models_dir.join(\"model.onnx\");\n        assert!(\n            model_path.exists(),\n            \"fixture missing: {}\",\n            model_path.display()\n        );\n        let tmp = tempfile::tempdir().unwrap();\n        let meta = crate::slicer::slice_model(&model_path, Some(tmp.path()), None, TEST_OPS, None)\n            .expect(\"slice_model failed\");\n        crate::slicer::materializer::ensure_all_slices_materialized(tmp.path(), &meta)\n            .expect(\"materialization failed\");\n        assert!(!meta.slices.is_empty(), \"model produced zero slices\");\n        let first_slice = &meta.slices[0];\n        let onnx_path = tmp\n            .path()\n            .join(format!(\"slice_0/payload/{}\", first_slice.filename));\n        assert!(\n            onnx_path.exists(),\n            \"sliced ONNX missing: {}\",\n            onnx_path.display()\n        );\n        let input_shape = &first_slice.shape.tensor_shape.input;\n        assert!(\n            !input_shape.is_empty() && !input_shape[0].is_empty(),\n            \"empty input shape\"\n        );\n        let shape: Vec<usize> = input_shape[0].iter().map(|&d| d.max(1) as usize).collect();\n        let elem_count: usize = shape.iter().product();\n\n        let warm = WarmModel::load(&onnx_path, &shape).expect(\"WarmModel::load failed\");\n        let input = vec![0.0f64; elem_count];\n        let (data1, shape1) = warm.run(&input).unwrap();\n        let (data2, shape2) = warm.run(&input).unwrap();\n        assert!(!data1.is_empty());\n        assert_eq!(shape1, shape2);\n        
assert_eq!(data1, data2);\n    }\n\n    #[test]\n    fn zip_named_outputs_empty() {\n        let result = zip_named_outputs(&[], &[]).unwrap();\n        assert!(result.is_empty());\n    }\n\n    #[test]\n    fn extract_first_output_empty() {\n        let result = extract_first_output(&[]);\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn build_input_tvalue_respects_declared_dtypes() {\n        let shape = [2usize, 3];\n        let values: Vec<f64> = (0..6).map(|v| v as f64).collect();\n\n        let tv_f32 = build_input_tvalue(&values, &shape, f32::datum_type()).unwrap();\n        assert_eq!(tv_f32.datum_type(), f32::datum_type());\n        assert_eq!(tv_f32.shape(), &shape);\n\n        let tv_u8 = build_input_tvalue(&values, &shape, u8::datum_type()).unwrap();\n        assert_eq!(tv_u8.datum_type(), u8::datum_type());\n\n        let tv_i64 = build_input_tvalue(&values, &shape, i64::datum_type()).unwrap();\n        assert_eq!(tv_i64.datum_type(), i64::datum_type());\n\n        let bool_vals = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];\n        let tv_bool = build_input_tvalue(&bool_vals, &shape, bool::datum_type()).unwrap();\n        assert_eq!(tv_bool.datum_type(), bool::datum_type());\n        let view = tv_bool.to_plain_array_view::<bool>().unwrap();\n        assert_eq!(\n            view.iter().copied().collect::<Vec<_>>(),\n            vec![false, true, false, true, false, true]\n        );\n\n        let unsupported = build_input_tvalue(&values, &shape, DatumType::String);\n        assert!(unsupported.is_err());\n    }\n\n    #[test]\n    fn build_input_tvalue_rejects_non_finite() {\n        let shape = [3usize];\n        for dt in [\n            f32::datum_type(),\n            f64::datum_type(),\n            u8::datum_type(),\n            i64::datum_type(),\n            bool::datum_type(),\n        ] {\n            for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {\n                let err = build_input_tvalue(&[0.0, bad, 1.0], &shape, 
dt).unwrap_err();\n                let msg = format!(\"{err:?}\");\n                assert!(\n                    msg.contains(\"non-finite\"),\n                    \"expected non-finite error for dt={dt:?} val={bad}, got {msg}\"\n                );\n            }\n        }\n    }\n\n    #[test]\n    fn build_input_tvalue_rejects_fractional_for_integer_dtypes() {\n        let shape = [2usize];\n        for dt in [\n            u8::datum_type(),\n            i8::datum_type(),\n            u32::datum_type(),\n            i32::datum_type(),\n            i64::datum_type(),\n            u64::datum_type(),\n        ] {\n            let err = build_input_tvalue(&[0.0, 1.5], &shape, dt).unwrap_err();\n            let msg = format!(\"{err:?}\");\n            assert!(\n                msg.contains(\"fractional\"),\n                \"expected fractional error for dt={dt:?}, got {msg}\"\n            );\n        }\n    }\n\n    #[test]\n    fn build_input_tvalue_rejects_out_of_range_for_integer_dtypes() {\n        let shape = [2usize];\n        let cases: &[(DatumType, f64)] = &[\n            (u8::datum_type(), 256.0),\n            (u8::datum_type(), -1.0),\n            (i8::datum_type(), 128.0),\n            (i8::datum_type(), -129.0),\n            (u16::datum_type(), -1.0),\n            (i16::datum_type(), 32_768.0),\n            (u32::datum_type(), -1.0),\n        ];\n        for (dt, bad) in cases.iter().copied() {\n            let err = build_input_tvalue(&[0.0, bad], &shape, dt).unwrap_err();\n            let msg = format!(\"{err:?}\");\n            assert!(\n                msg.contains(\"outside\"),\n                \"expected range error for dt={dt:?} val={bad}, got {msg}\"\n            );\n        }\n    }\n\n    #[test]\n    fn safe_integer_bound_is_inclusive_on_both_sides() {\n        let shape = [3usize];\n        let bound = I64_SAFE_BOUND as f64;\n        build_input_tvalue(&[0.0, bound, -bound], &shape, i64::datum_type())\n            .expect(\"i64 accepts +/- 
I64_SAFE_BOUND\");\n        build_input_tvalue(&[0.0, bound, 1.0], &shape, u64::datum_type())\n            .expect(\"u64 accepts I64_SAFE_BOUND\");\n\n        i64_to_f64_checked(I64_SAFE_BOUND, \"i64\")\n            .expect(\"i64_to_f64_checked accepts I64_SAFE_BOUND\");\n        i64_to_f64_checked(-I64_SAFE_BOUND, \"i64\")\n            .expect(\"i64_to_f64_checked accepts -I64_SAFE_BOUND\");\n        u64_to_f64_checked(I64_SAFE_BOUND as u64, \"u64\")\n            .expect(\"u64_to_f64_checked accepts I64_SAFE_BOUND\");\n\n        assert!(i64_to_f64_checked(I64_SAFE_BOUND + 1, \"i64\").is_err());\n        assert!(u64_to_f64_checked(I64_SAFE_BOUND as u64 + 1, \"u64\").is_err());\n    }\n\n    #[test]\n    fn build_input_tvalue_rejects_i64_above_safe_integer_bound() {\n        let shape = [2usize];\n        let unsafe_hi = (I64_SAFE_BOUND as f64) + 1024.0;\n        let err = build_input_tvalue(&[0.0, unsafe_hi], &shape, i64::datum_type()).unwrap_err();\n        let msg = format!(\"{err:?}\");\n        assert!(\n            msg.contains(\"safe integer bound\"),\n            \"expected safe-integer-bound error, got {msg}\"\n        );\n    }\n\n    #[test]\n    fn build_input_tvalue_rejects_finite_f64_outside_f32_range() {\n        let shape = [2usize];\n        for bad in [1.0e40_f64, -1.0e40_f64] {\n            assert!(bad.is_finite());\n            let err = build_input_tvalue(&[0.0, bad], &shape, f32::datum_type()).unwrap_err();\n            let msg = format!(\"{err:?}\");\n            assert!(\n                msg.contains(\"representable f32 range\"),\n                \"expected f32-range error for val={bad}, got {msg}\"\n            );\n        }\n        let ok = build_input_tvalue(\n            &[0.0, f32::MAX as f64, -(f32::MAX as f64)],\n            &[3],\n            f32::datum_type(),\n        )\n        .unwrap();\n        let view = ok.to_plain_array_view::<f32>().unwrap();\n        assert!(view.iter().all(|v| v.is_finite()));\n    }\n\n    #[test]\n    
fn build_input_tvalue_rejects_non_boolean_for_bool_dtype() {\n        let shape = [2usize];\n        let err = build_input_tvalue(&[0.0, 2.0], &shape, bool::datum_type()).unwrap_err();\n        let msg = format!(\"{err:?}\");\n        assert!(\n            msg.contains(\"bool inputs must be exactly 0 or 1\"),\n            \"expected strict bool error, got {msg}\"\n        );\n    }\n\n    fn write_uint8_cast_to_float_model(path: &Path) {\n        use crate::slicer::onnx_proto;\n        let input = onnx_proto::make_tensor_value_info(\"x\", 2, &[3]); // 2 = UINT8\n        let output = onnx_proto::make_tensor_value_info(\"y\", 1, &[3]); // 1 = FLOAT\n        let cast_to = onnx_proto::make_attribute_int(\"to\", 1);\n        let node = onnx_proto::make_node(\n            \"Cast\",\n            vec![\"x\".to_string()],\n            vec![\"y\".to_string()],\n            vec![cast_to],\n        );\n        let graph = onnx_proto::make_graph(\"g\", vec![node], vec![input], vec![output], vec![]);\n        let model = onnx_proto::make_model(graph, 13);\n        onnx_proto::save_model(&model, path).unwrap();\n    }\n\n    fn write_uint8_identity_model(path: &Path) {\n        use crate::slicer::onnx_proto;\n        let input = onnx_proto::make_tensor_value_info(\"x\", 2, &[3]); // UINT8\n        let output = onnx_proto::make_tensor_value_info(\"y\", 2, &[3]); // UINT8\n        let node = onnx_proto::make_node(\n            \"Identity\",\n            vec![\"x\".to_string()],\n            vec![\"y\".to_string()],\n            vec![],\n        );\n        let graph = onnx_proto::make_graph(\"g\", vec![node], vec![input], vec![output], vec![]);\n        let model = onnx_proto::make_model(graph, 13);\n        onnx_proto::save_model(&model, path).unwrap();\n    }\n\n    #[test]\n    fn warm_model_decodes_uint8_output() {\n        let tmp = tempfile::tempdir().unwrap();\n        let onnx_path = tmp.path().join(\"u8_identity.onnx\");\n        write_uint8_identity_model(&onnx_path);\n\n 
       let shape = [3usize];\n        let warm = WarmModel::load(&onnx_path, &shape).expect(\"WarmModel::load\");\n        assert_eq!(warm.input_dt, u8::datum_type());\n        let (data, out_shape) = warm.run(&[0.0, 128.0, 255.0]).unwrap();\n        assert_eq!(out_shape, shape.to_vec());\n        assert_eq!(data, vec![0.0, 128.0, 255.0]);\n    }\n\n    #[test]\n    fn tvalue_to_f64_covers_added_integer_dtypes() {\n        fn tv_of<T: Datum>(values: &[T]) -> TValue {\n            let arr =\n                tract_ndarray::ArrayD::from_shape_vec(IxDyn(&[values.len()]), values.to_vec())\n                    .unwrap();\n            arr.into_tvalue()\n        }\n        let (d, s) = tvalue_to_f64(&tv_of::<u8>(&[0, 255]), \"u8\").unwrap();\n        assert_eq!((d, s), (vec![0.0, 255.0], vec![2]));\n        let (d, _) = tvalue_to_f64(&tv_of::<i8>(&[-128, 127]), \"i8\").unwrap();\n        assert_eq!(d, vec![-128.0, 127.0]);\n        let (d, _) = tvalue_to_f64(&tv_of::<u16>(&[0, 65_535]), \"u16\").unwrap();\n        assert_eq!(d, vec![0.0, 65_535.0]);\n        let (d, _) = tvalue_to_f64(&tv_of::<i16>(&[-32_768, 32_767]), \"i16\").unwrap();\n        assert_eq!(d, vec![-32_768.0, 32_767.0]);\n        let (d, _) = tvalue_to_f64(&tv_of::<u32>(&[0, u32::MAX]), \"u32\").unwrap();\n        assert_eq!(d, vec![0.0, u32::MAX as f64]);\n        let (d, _) = tvalue_to_f64(&tv_of::<u64>(&[0, 1_000_000]), \"u64\").unwrap();\n        assert_eq!(d, vec![0.0, 1_000_000.0]);\n\n        let unsafe_hi = (I64_SAFE_BOUND as u64) + 7;\n        let err = tvalue_to_f64(&tv_of::<u64>(&[unsafe_hi]), \"u64\").unwrap_err();\n        assert!(\n            format!(\"{err:?}\").contains(\"safe integer bound\"),\n            \"expected u64 safe-bound error\"\n        );\n    }\n\n    #[test]\n    fn warm_model_runs_non_f32_input_through_planner() {\n        let tmp = tempfile::tempdir().unwrap();\n        let onnx_path = tmp.path().join(\"u8_cast.onnx\");\n        
write_uint8_cast_to_float_model(&onnx_path);\n\n        let shape = [3usize];\n        let warm = WarmModel::load(&onnx_path, &shape).expect(\"WarmModel::load\");\n        assert_eq!(warm.input_dt, u8::datum_type());\n        let (data, out_shape) = warm.run(&[0.0, 42.0, 255.0]).unwrap();\n        assert_eq!(out_shape, shape.to_vec());\n        assert_eq!(data, vec![0.0, 42.0, 255.0]);\n\n        // A second call with a value that can't round-trip through u8 must error\n        // from build_input_tvalue before the planner is invoked.\n        let err = warm.run(&[0.0, 256.0, 0.0]).unwrap_err();\n        assert!(format!(\"{err:?}\").contains(\"outside\"));\n    }\n\n    #[test]\n    fn run_inference_multi_honors_per_input_dtype() {\n        let tmp = tempfile::tempdir().unwrap();\n        let onnx_path = tmp.path().join(\"u8_cast.onnx\");\n        write_uint8_cast_to_float_model(&onnx_path);\n\n        let inputs: Vec<(&str, Vec<f64>, Vec<usize>)> = vec![(\"x\", vec![1.0, 2.0, 3.0], vec![3])];\n        let out = run_inference_multi_named(&onnx_path, &inputs).unwrap();\n        let (data, shape) = out.values().next().expect(\"at least one output\");\n        assert_eq!(shape, &vec![3]);\n        assert_eq!(data, &vec![1.0, 2.0, 3.0]);\n    }\n\n    #[test]\n    fn resolve_input_datum_type_reads_concrete_model_dtype() {\n        let tmp = tempfile::tempdir().unwrap();\n        let onnx_path = tmp.path().join(\"u8_cast.onnx\");\n        write_uint8_cast_to_float_model(&onnx_path);\n        let model = load_onnx_model(&onnx_path).unwrap();\n        let dt = resolve_input_datum_type(&model, 0).unwrap();\n        assert_eq!(dt, u8::datum_type());\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/backend/traits.rs",
    "content": "use std::path::Path;\n\nuse crate::error::Result;\n\npub trait ProofBackend: Send + Sync {\n    fn prove(&self, circuit_path: &Path, witness_bytes: &[u8]) -> Result<Vec<u8>>;\n\n    fn verify(&self, circuit_path: &Path, witness_bytes: &[u8], proof_bytes: &[u8])\n    -> Result<bool>;\n\n    fn witness_f64(\n        &self,\n        circuit_path: &Path,\n        activations: &[f64],\n        initializers: &[(Vec<f64>, Vec<usize>)],\n    ) -> Result<Vec<u8>>;\n}\n"
  },
  {
    "path": "crates/dsperse/src/cli/mod.rs",
    "content": "use std::num::NonZeroUsize;\nuse std::path::{Path, PathBuf};\n\nuse clap::{Args, Parser, Subcommand};\n\nuse crate::backend::jstprove::{JstproveBackend, ProofConfig};\nuse crate::error::{DsperseError, Result};\nuse crate::pipeline::{self, RunConfig};\n\nuse jstprove_circuits::api::{ProofConfigError, ProofSystemType as ProofSystem};\n\nfn parse_proof_config(value: &str) -> Result<ProofConfig> {\n    value.parse().map_err(|e: ProofConfigError| {\n        DsperseError::Other(format!(\"invalid --curve '{value}': {e}\"))\n    })\n}\n\npub const VERSION: &str = env!(\"DSPERSE_DISPLAY_VERSION\");\n\n#[derive(Parser)]\n#[command(name = \"dsperse\", about = \"Distributed zkML Toolkit\", version = VERSION)]\npub struct Cli {\n    #[command(subcommand)]\n    pub command: Commands,\n    #[arg(long, default_value = \"warn\", global = true)]\n    pub log_level: String,\n}\n\n#[derive(Subcommand)]\npub enum Commands {\n    Slice(SliceArgs),\n    Combine(CombineArgs),\n    Compile(CompileArgs),\n    Run(RunArgs),\n    Prove(ProveArgs),\n    Verify(VerifyArgs),\n    Package(PackageArgs),\n    Publish(PublishArgs),\n    #[command(name = \"full-run\")]\n    FullRun(FullRunArgs),\n    Analyze(AnalyzeArgs),\n    #[command(name = \"setup-holographic\")]\n    SetupHolographic(SetupHolographicArgs),\n}\n\npub fn dispatch(command: Commands) -> Result<()> {\n    match command {\n        Commands::Slice(args) => cmd_slice(args),\n        Commands::Combine(args) => cmd_combine(args),\n        Commands::Compile(args) => cmd_compile(args),\n        Commands::Run(args) => cmd_run(args),\n        Commands::Prove(args) => cmd_prove(args),\n        Commands::Verify(args) => cmd_verify(args),\n        Commands::Package(args) => cmd_package(args),\n        Commands::Publish(args) => cmd_publish(args),\n        Commands::FullRun(args) => cmd_full_run(args),\n        Commands::Analyze(args) => cmd_analyze(args),\n        Commands::SetupHolographic(args) => cmd_setup_holographic(args),\n 
   }\n}\n\n#[derive(Args)]\npub struct SliceArgs {\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub output_dir: Option<PathBuf>,\n    #[arg(long, default_value = \"512\")]\n    pub tile_size: Option<usize>,\n    #[arg(\n        long,\n        default_value = \"expander\",\n        help = \"Proof system backend (expander or remainder)\"\n    )]\n    pub proof_system: String,\n    #[arg(\n        long,\n        help = \"Comma-separated ONNX op names to compile via the proof backend (default: all supported)\"\n    )]\n    pub circuit_ops: Option<String>,\n    #[arg(\n        long,\n        value_delimiter = ',',\n        help = \"Concrete input shape as comma-separated dims (e.g. 1,3,560,560). Overrides dynamic dimensions.\"\n    )]\n    pub input_shape: Option<Vec<i64>>,\n}\n\n#[derive(Args)]\npub struct CombineArgs {\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub slices_dir: Option<PathBuf>,\n}\n\n#[derive(Args)]\npub struct CompileArgs {\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub slices_dir: Option<PathBuf>,\n    #[arg(long)]\n    pub layers: Option<String>,\n    #[arg(long, default_value = \"1\")]\n    pub parallel: NonZeroUsize,\n    #[arg(\n        long,\n        default_value_t = true,\n        action = clap::ArgAction::Set,\n        help = \"Compile circuits with weights as inputs for shared circuit reuse (default: true)\"\n    )]\n    pub weights_as_inputs: bool,\n    #[arg(\n        long,\n        default_value = \"expander\",\n        help = \"Proof system backend (expander or remainder)\"\n    )]\n    pub proof_system: String,\n    #[arg(\n        long,\n        help = \"Comma-separated ONNX op names to compile via the proof backend (default: all supported)\"\n    )]\n    pub circuit_ops: Option<String>,\n    #[arg(\n        long = \"proof-config\",\n        visible_alias = \"curve\",\n        default_value = \"bn254_raw\",\n        help = \"Proof config: bn254_raw, 
goldilocks_basefold, goldilocks_ext2_basefold, goldilocks_ext3_whir, goldilocks_ext4_whir. The --curve alias is retained for backward compatibility and will be removed in a future release.\"\n    )]\n    pub curve: String,\n    #[arg(\n        long,\n        help = \"Skip compilation of slices whose estimated constraint count exceeds this threshold\"\n    )]\n    pub skip_compile_over_size: Option<u64>,\n    #[arg(\n        long,\n        default_value_t = false,\n        action = clap::ArgAction::Set,\n        help = \"Allow the command to exit 0 when individual slices fail to compile.  Failed slices fall back to ONNX execution at run / prove time, producing a partial-coverage proof.  Off by default so CI surfaces real compile regressions.\"\n    )]\n    pub allow_onnx_fallback: bool,\n    #[arg(\n        long,\n        default_value_t = false,\n        action = clap::ArgAction::Set,\n        help = \"After compiling each slice, run holographic GKR setup and persist the verifying key as vk.bin in the bundle directory. 
Requires --proof-config goldilocks_ext4_whir.\"\n    )]\n    pub holographic: bool,\n}\n\n#[derive(Args)]\npub struct RunArgs {\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub input_file: PathBuf,\n    #[arg(long)]\n    pub run_dir: Option<PathBuf>,\n    #[arg(long)]\n    pub slices_dir: Option<PathBuf>,\n    #[arg(long, default_value = \"1\")]\n    pub parallel: NonZeroUsize,\n    #[arg(long)]\n    pub batch: bool,\n    #[arg(\n        long,\n        help = \"Path to consumer ONNX with fine-tuned weights to inject at inference time\"\n    )]\n    pub weights: Option<PathBuf>,\n    #[arg(\n        long,\n        default_value_t = true,\n        action = clap::ArgAction::Set,\n        help = \"Run inference on combined monolithic ONNX instead of per-slice execution\"\n    )]\n    pub combined: bool,\n}\n\n#[derive(Args)]\npub struct ProveArgs {\n    #[arg(long)]\n    pub run_dir: PathBuf,\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub slices_dir: Option<PathBuf>,\n    #[arg(long, default_value = \"1\")]\n    pub parallel: NonZeroUsize,\n}\n\n#[derive(Args)]\npub struct VerifyArgs {\n    #[arg(long)]\n    pub run_dir: PathBuf,\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub slices_dir: Option<PathBuf>,\n    #[arg(long, default_value = \"1\")]\n    pub parallel: NonZeroUsize,\n}\n\n#[derive(Args)]\npub struct PackageArgs {\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub slices_dir: Option<PathBuf>,\n    #[arg(long)]\n    pub output_dir: Option<PathBuf>,\n    #[arg(long)]\n    pub author: Option<String>,\n    #[arg(long)]\n    pub model_version: Option<String>,\n    #[arg(long)]\n    pub model_name: Option<String>,\n    #[arg(long)]\n    pub timeout: Option<u64>,\n    #[arg(\n        long,\n        help = \"Finite field curve used as domain separator in content hashes (bn254, goldilocks, goldilocks_basefold, goldilocks_ext2, goldilocks_whir, goldilocks_whir_pq)\"\n  
  )]\n    pub curve: Option<String>,\n}\n\n#[derive(Args)]\npub struct PublishArgs {\n    #[arg(long, help = \"Package directory containing manifest.msgpack\")]\n    pub dir: PathBuf,\n    #[arg(long, help = \"Registry base URL\")]\n    pub url: String,\n    #[arg(long, env = \"REGISTRY_AUTH_TOKEN\", hide_env_values = true)]\n    pub auth_token: String,\n    #[arg(long)]\n    pub name: String,\n    #[arg(long, default_value = \"\")]\n    pub description: String,\n    #[arg(long)]\n    pub author: String,\n    #[arg(long, default_value = \"1.0.0\")]\n    pub version: String,\n    #[arg(long, default_value = \"JSTPROVE\")]\n    pub proof_system: String,\n    #[arg(long, default_value = \"3600\")]\n    pub timeout: u64,\n    #[arg(long, default_value_t = false, help = \"Activate model after upload\")]\n    pub activate: bool,\n}\n\n#[derive(Args)]\npub struct FullRunArgs {\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub input_file: Option<PathBuf>,\n    #[arg(long)]\n    pub slices_dir: Option<PathBuf>,\n    #[arg(long)]\n    pub layers: Option<String>,\n    #[arg(\n        long,\n        default_value_t = true,\n        action = clap::ArgAction::Set,\n        help = \"Compile circuits with weights as inputs for shared circuit reuse (default: true)\"\n    )]\n    pub weights_as_inputs: bool,\n    #[arg(long, default_value = \"1\")]\n    pub parallel: NonZeroUsize,\n    #[arg(long)]\n    pub batch: bool,\n    #[arg(\n        long,\n        help = \"Path to consumer ONNX with fine-tuned weights to inject at inference time\"\n    )]\n    pub weights: Option<PathBuf>,\n    #[arg(\n        long,\n        default_value = \"expander\",\n        help = \"Proof system backend (expander or remainder)\"\n    )]\n    pub proof_system: String,\n    #[arg(\n        long,\n        help = \"Comma-separated ONNX op names to compile via the proof backend (default: all supported)\"\n    )]\n    pub circuit_ops: Option<String>,\n    #[arg(\n        long,\n       
 default_value_t = true,\n        action = clap::ArgAction::Set,\n        help = \"Run inference on combined monolithic ONNX instead of per-slice execution\"\n    )]\n    pub combined: bool,\n    #[arg(\n        long = \"proof-config\",\n        visible_alias = \"curve\",\n        default_value = \"bn254_raw\",\n        help = \"Proof config: bn254_raw, goldilocks_basefold, goldilocks_ext2_basefold, goldilocks_ext3_whir, goldilocks_ext4_whir. The --curve alias is retained for backward compatibility and will be removed in a future release.\"\n    )]\n    pub curve: String,\n    #[arg(\n        long,\n        help = \"Skip compilation of slices whose estimated constraint count exceeds this threshold\"\n    )]\n    pub skip_compile_over_size: Option<u64>,\n    #[arg(\n        long,\n        default_value_t = false,\n        action = clap::ArgAction::Set,\n        help = \"Allow full-run to proceed when individual slices fail to compile.  Failed slices fall back to ONNX execution, producing a partial-coverage proof.  Off by default so CI surfaces real compile regressions.\"\n    )]\n    pub allow_onnx_fallback: bool,\n    #[arg(\n        long,\n        default_value_t = false,\n        action = clap::ArgAction::Set,\n        help = \"After compiling each slice, run holographic GKR setup and persist the verifying key as vk.bin in the bundle directory. 
Requires --proof-config goldilocks_ext4_whir.\"\n    )]\n    pub holographic: bool,\n}\n\n#[derive(Args)]\npub struct SetupHolographicArgs {\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub slices_dir: Option<PathBuf>,\n    #[arg(long, default_value = \"1\")]\n    pub parallel: NonZeroUsize,\n    #[arg(\n        long,\n        default_value_t = false,\n        action = clap::ArgAction::Set,\n        help = \"Re-run setup and overwrite vk.bin even when the bundle already has one\"\n    )]\n    pub overwrite: bool,\n}\n\nstruct CircuitOps(Vec<String>);\n\nimpl CircuitOps {\n    fn as_refs(&self) -> Vec<&str> {\n        self.0.iter().map(String::as_str).collect()\n    }\n}\n\nfn resolve_circuit_ops(proof_system_str: &str, circuit_ops: Option<&str>) -> Result<CircuitOps> {\n    let ps: ProofSystem =\n        proof_system_str\n            .parse()\n            .map_err(|e: jstprove_circuits::api::ProofSystemParseError| {\n                DsperseError::Other(e.to_string())\n            })?;\n\n    let supported = ps.supported_ops();\n\n    let ops = match circuit_ops {\n        None => supported.iter().map(|s| (*s).to_string()).collect(),\n        Some(spec) => {\n            let requested: Vec<String> = spec\n                .split(',')\n                .map(|s| s.trim().to_string())\n                .filter(|s| !s.is_empty())\n                .collect();\n            if requested.is_empty() {\n                return Err(DsperseError::Other(\n                    \"empty --circuit-ops; provide at least one op or omit the flag to use all supported ops\".into(),\n                ));\n            }\n            for op in &requested {\n                if !supported.contains(&op.as_str()) {\n                    return Err(DsperseError::Other(format!(\n                        \"op {op:?} is not supported by proof system {ps}. 
Supported: {supported:?}\"\n                    )));\n                }\n            }\n            requested\n        }\n    };\n    Ok(CircuitOps(ops))\n}\n\nfn resolve_slices_dir(slices_dir: Option<PathBuf>, model_dir: &Path) -> PathBuf {\n    slices_dir.unwrap_or_else(|| model_dir.join(\"slices\"))\n}\n\npub fn cmd_slice(args: SliceArgs) -> Result<()> {\n    let model_path = args.model_dir.join(\"model.onnx\");\n    if !model_path.exists() {\n        return Err(DsperseError::Slicer(format!(\n            \"model.onnx not found in {}\",\n            args.model_dir.display()\n        )));\n    }\n    let ops = resolve_circuit_ops(&args.proof_system, args.circuit_ops.as_deref())?;\n    let metadata = crate::slicer::slice_model(\n        &model_path,\n        args.output_dir.as_deref(),\n        args.tile_size,\n        &ops.as_refs(),\n        args.input_shape.as_deref(),\n    )?;\n    tracing::info!(slices = metadata.slices.len(), \"slicing complete\");\n    Ok(())\n}\n\npub fn cmd_combine(args: CombineArgs) -> Result<()> {\n    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);\n    let meta = pipeline::runner::load_model_metadata(&slices_dir)?;\n    let path = crate::slicer::combiner::materialize_combined_to_disk(&slices_dir, &meta)?;\n    tracing::info!(path = %path.display(), \"combined ONNX materialized\");\n    Ok(())\n}\n\npub fn cmd_compile(args: CompileArgs) -> Result<()> {\n    let proof_config = parse_proof_config(&args.curve)?;\n    let backend = JstproveBackend::new();\n    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);\n\n    let layers = args\n        .layers\n        .as_ref()\n        .map(|s| parse_index_spec(s))\n        .transpose()?;\n\n    let ops = resolve_circuit_ops(&args.proof_system, args.circuit_ops.as_deref())?;\n\n    let report = pipeline::compile_slices(\n        &slices_dir,\n        &backend,\n        proof_config,\n        args.parallel.get(),\n        args.weights_as_inputs,\n        
layers.as_deref(),\n        &ops.as_refs(),\n        args.skip_compile_over_size,\n        args.holographic,\n    )?;\n    if args.allow_onnx_fallback {\n        Ok(())\n    } else {\n        report.ok_if_no_failures().map(|_| ())\n    }\n}\n\npub fn cmd_run(args: RunArgs) -> Result<()> {\n    if !args.input_file.is_file() {\n        return Err(DsperseError::Other(format!(\n            \"input file not found: {}\",\n            args.input_file.display()\n        )));\n    }\n\n    let backend = JstproveBackend::new();\n    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);\n\n    let run_dir = args\n        .run_dir\n        .unwrap_or_else(|| args.model_dir.join(\"run\").join(format!(\"run_{}\", run_id())));\n\n    let config = RunConfig {\n        parallel: args.parallel.get(),\n        batch: args.batch,\n        weights_onnx: args.weights,\n        combined: args.combined,\n    };\n\n    pipeline::run_inference(&slices_dir, &args.input_file, &run_dir, &backend, &config)?;\n    Ok(())\n}\n\npub fn cmd_prove(args: ProveArgs) -> Result<()> {\n    let backend = JstproveBackend::new();\n    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);\n\n    pipeline::prove_run(&args.run_dir, &slices_dir, &backend, args.parallel.get())?;\n    Ok(())\n}\n\npub fn cmd_verify(args: VerifyArgs) -> Result<()> {\n    let backend = JstproveBackend::new();\n    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);\n\n    pipeline::verify_run(&args.run_dir, &slices_dir, &backend, args.parallel.get())?;\n    Ok(())\n}\n\npub fn cmd_package(args: PackageArgs) -> Result<()> {\n    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);\n    let output_dir = args\n        .output_dir\n        .unwrap_or_else(|| args.model_dir.join(\"package\"));\n\n    let config = pipeline::packager::PackageConfig {\n        output_dir,\n        author: args.author,\n        model_version: args.model_version,\n        model_name: 
args.model_name,\n        timeout: args.timeout,\n        curve: args.curve,\n    };\n\n    let result = pipeline::packager::package_content_addressed(&slices_dir, &config)?;\n\n    tracing::info!(\n        components = result.component_count,\n        weight_biases = result.wb_count,\n        total_bytes = result.total_size,\n        manifest = %result.manifest_path.display(),\n        \"content-addressed packaging complete\"\n    );\n\n    Ok(())\n}\n\npub fn cmd_publish(args: PublishArgs) -> Result<()> {\n    let config = pipeline::publisher::PublishConfig {\n        api_url: args.url,\n        auth_token: args.auth_token,\n        name: args.name,\n        description: args.description,\n        author: args.author,\n        version: args.version,\n        proof_system: args.proof_system,\n        timeout: args.timeout,\n        activate: args.activate,\n    };\n\n    let result = match pipeline::publisher::publish(&args.dir, &config) {\n        Ok(r) => r,\n        Err(e) => {\n            tracing::error!(error = %e, \"publish failed\");\n            return Err(e);\n        }\n    };\n\n    tracing::info!(\n        model_id = %result.model_id,\n        components_uploaded = result.components_uploaded,\n        components_skipped = result.components_skipped,\n        weights_uploaded = result.weights_uploaded,\n        weights_skipped = result.weights_skipped,\n        \"publish complete\"\n    );\n\n    Ok(())\n}\n\npub fn cmd_full_run(args: FullRunArgs) -> Result<()> {\n    let proof_config = parse_proof_config(&args.curve)?;\n    let backend = JstproveBackend::new();\n\n    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);\n\n    let input_file = args\n        .input_file\n        .unwrap_or_else(|| args.model_dir.join(crate::utils::paths::INPUT_FILE));\n\n    if !input_file.is_file() {\n        return Err(DsperseError::Other(format!(\n            \"input file not found: {}\",\n            input_file.display()\n        )));\n    }\n\n    
if args.weights.is_some() && !args.weights_as_inputs {\n        return Err(DsperseError::Other(\n            \"--weights requires --weights-as-inputs during compilation\".into(),\n        ));\n    }\n\n    let layers = args\n        .layers\n        .as_ref()\n        .map(|s| parse_index_spec(s))\n        .transpose()?;\n\n    let ops = resolve_circuit_ops(&args.proof_system, args.circuit_ops.as_deref())?;\n\n    tracing::info!(\"compiling slices\");\n    let report = pipeline::compile_slices(\n        &slices_dir,\n        &backend,\n        proof_config,\n        args.parallel.get(),\n        args.weights_as_inputs,\n        layers.as_deref(),\n        &ops.as_refs(),\n        args.skip_compile_over_size,\n        args.holographic,\n    )?;\n    if !args.allow_onnx_fallback {\n        report.ok_if_no_failures()?;\n    }\n\n    let run_dir = args.model_dir.join(\"run\").join(format!(\"run_{}\", run_id()));\n\n    let config = RunConfig {\n        parallel: args.parallel.get(),\n        batch: args.batch,\n        weights_onnx: args.weights,\n        combined: args.combined,\n    };\n\n    tracing::info!(\"running inference\");\n    pipeline::run_inference(&slices_dir, &input_file, &run_dir, &backend, &config)?;\n\n    tracing::info!(\"proving\");\n    pipeline::prove_run(&run_dir, &slices_dir, &backend, args.parallel.get())?;\n\n    tracing::info!(\"verifying\");\n    pipeline::verify_run(&run_dir, &slices_dir, &backend, args.parallel.get())?;\n\n    tracing::info!(run_dir = %run_dir.display(), \"full run complete\");\n    Ok(())\n}\n\npub fn cmd_setup_holographic(args: SetupHolographicArgs) -> Result<()> {\n    let backend = JstproveBackend::new();\n    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);\n\n    let report = pipeline::setup_holographic_for_slices(\n        &slices_dir,\n        &backend,\n        args.parallel.get(),\n        args.overwrite,\n    )?;\n\n    tracing::info!(\n        processed = report.processed,\n        skipped 
= report.skipped_already_present,\n        failed = report.failed.len(),\n        \"holographic setup complete\"\n    );\n\n    report.ok_if_no_failures().map(|_| ())\n}\n\n#[derive(Args)]\npub struct AnalyzeArgs {\n    #[arg(long)]\n    pub model_dir: PathBuf,\n    #[arg(long)]\n    pub slices_dir: Option<PathBuf>,\n    #[arg(\n        long,\n        default_value = \"expander\",\n        help = \"Proof system backend (expander or remainder)\"\n    )]\n    pub proof_system: String,\n    #[arg(\n        long,\n        help = \"Comma-separated ONNX op names to compile via the proof backend\"\n    )]\n    pub circuit_ops: Option<String>,\n    #[arg(\n        long,\n        help = \"Skip slices whose estimated constraint count exceeds this\"\n    )]\n    pub skip_compile_over_size: Option<u64>,\n    #[arg(\n        long = \"proof-config\",\n        visible_alias = \"curve\",\n        default_value = \"bn254_raw\",\n        help = \"Proof config for circuit signature computation\"\n    )]\n    pub proof_config: String,\n    #[arg(\n        long,\n        default_value_t = AnalyzeFormat::Table,\n        value_enum,\n        help = \"Output format\"\n    )]\n    pub format: AnalyzeFormat,\n}\n\n#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum)]\npub enum AnalyzeFormat {\n    Table,\n    Json,\n}\n\nimpl std::fmt::Display for AnalyzeFormat {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            Self::Table => f.write_str(\"table\"),\n            Self::Json => f.write_str(\"json\"),\n        }\n    }\n}\n\nfn cmd_analyze(args: AnalyzeArgs) -> Result<()> {\n    let slices_dir = resolve_slices_dir(args.slices_dir, &args.model_dir);\n    let ops = resolve_circuit_ops(&args.proof_system, args.circuit_ops.as_deref())?;\n    // Validate proof_config through the same parser cmd_compile and\n    // cmd_full_run use so a typo in --proof-config fails fast with a\n    // \"unknown proof config 'foo'\" message rather 
than silently\n    // producing signatures under an unintended curve.\n    let proof_config = parse_proof_config(&args.proof_config)?;\n    let proof_config_name = proof_config.to_string();\n\n    let reports = pipeline::analyze_slices(\n        &slices_dir,\n        &ops.as_refs(),\n        args.skip_compile_over_size,\n        Some(proof_config_name.as_str()),\n    )?;\n\n    if matches!(args.format, AnalyzeFormat::Json) {\n        println!(\n            \"{}\",\n            serde_json::to_string_pretty(&reports)\n                .map_err(|e| DsperseError::Other(e.to_string()))?\n        );\n    } else {\n        let hdr_ops = \"OPS\";\n        println!(\n            \"{:<8} {:<10} {:<28} {:<14} {:<6} {:<6} {:<6} {:<12} {hdr_ops}\",\n            \"SLICE\", \"BACKEND\", \"REASON\", \"EST.CONSTR\", \"TILED\", \"CHSPL\", \"DMSPL\", \"SIGNATURE\"\n        );\n        println!(\"{}\", \"-\".repeat(120));\n\n        let mut jstprove_count = 0usize;\n        let mut onnx_count = 0usize;\n        let mut missing_count = 0usize;\n        let mut total_constraints: u64 = 0;\n        let mut unique_sigs: std::collections::HashSet<String> = std::collections::HashSet::new();\n\n        for r in &reports {\n            let est = r\n                .estimated_constraints\n                .map(|c| format!(\"{c}\"))\n                .unwrap_or_default();\n            let sig = r\n                .circuit_signature\n                .as_deref()\n                .map(|s| &s[..12.min(s.len())])\n                .unwrap_or(\"\");\n            println!(\n                \"{:<8} {:<10} {:<28} {:<14} {:<6} {:<6} {:<6} {:<12} {}\",\n                r.index,\n                r.backend,\n                r.reason,\n                est,\n                r.tiled,\n                r.channel_split,\n                r.dim_split,\n                sig,\n                r.ops,\n            );\n            match r.backend.as_str() {\n                \"jstprove\" => jstprove_count += 1,\n              
  \"onnx\" => onnx_count += 1,\n                \"missing\" => missing_count += 1,\n                other => {\n                    tracing::warn!(\n                        slice = r.index,\n                        backend = other,\n                        \"analyze: unknown backend classification; not counted\"\n                    );\n                }\n            }\n            if let Some(c) = r.estimated_constraints {\n                total_constraints += c;\n            }\n            if let Some(ref s) = r.circuit_signature {\n                unique_sigs.insert(s.clone());\n            }\n        }\n\n        println!(\"{}\", \"-\".repeat(120));\n        println!(\n            \"total: {} slices | jstprove: {} | onnx: {} | missing: {} | unique circuits: {} | total constraints: {}\",\n            reports.len(),\n            jstprove_count,\n            onnx_count,\n            missing_count,\n            unique_sigs.len(),\n            total_constraints,\n        );\n    }\n\n    Ok(())\n}\n\nfn parse_index_spec(spec: &str) -> Result<Vec<usize>> {\n    let mut layers = Vec::new();\n    for part in spec.split(',') {\n        let part = part.trim();\n        if part.is_empty() {\n            continue;\n        }\n        if let Some((start, end)) = part.split_once('-') {\n            let s: usize = start.trim().parse().map_err(|_| {\n                DsperseError::Other(format!(\"invalid index spec range start: {start:?}\"))\n            })?;\n            let e: usize = end.trim().parse().map_err(|_| {\n                DsperseError::Other(format!(\"invalid index spec range end: {end:?}\"))\n            })?;\n            if s > e {\n                return Err(DsperseError::Other(format!(\n                    \"invalid index spec range: start {s} > end {e}\"\n                )));\n            }\n            layers.extend(s..=e);\n        } else {\n            let n: usize = part\n                .parse()\n                .map_err(|_| 
DsperseError::Other(format!(\"invalid index spec token: {part:?}\")))?;\n            layers.push(n);\n        }\n    }\n    if layers.is_empty() {\n        return Err(DsperseError::Other(\"empty index spec\".into()));\n    }\n    Ok(layers)\n}\n\nfn run_id() -> String {\n    let now = std::time::SystemTime::now()\n        .duration_since(std::time::UNIX_EPOCH)\n        .unwrap_or_default();\n    let uuid = uuid::Uuid::new_v4();\n    format!(\"{}_{}\", now.as_secs(), uuid.as_simple())\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use clap::Parser;\n\n    #[test]\n    fn parse_index_spec_single() {\n        assert_eq!(parse_index_spec(\"3\").unwrap(), vec![3]);\n    }\n\n    #[test]\n    fn parse_index_spec_multiple() {\n        assert_eq!(parse_index_spec(\"1,3,5\").unwrap(), vec![1, 3, 5]);\n    }\n\n    #[test]\n    fn parse_index_spec_range() {\n        assert_eq!(parse_index_spec(\"2-5\").unwrap(), vec![2, 3, 4, 5]);\n    }\n\n    #[test]\n    fn parse_index_spec_mixed() {\n        assert_eq!(parse_index_spec(\"0,2-4,7\").unwrap(), vec![0, 2, 3, 4, 7]);\n    }\n\n    #[test]\n    fn parse_index_spec_whitespace_tolerance() {\n        assert_eq!(parse_index_spec(\" 1 , 2 - 3 \").unwrap(), vec![1, 2, 3]);\n    }\n\n    #[test]\n    fn parse_index_spec_empty_rejected() {\n        assert!(parse_index_spec(\"\").is_err());\n    }\n\n    #[test]\n    fn parse_index_spec_invalid_token() {\n        assert!(parse_index_spec(\"abc\").is_err());\n    }\n\n    #[test]\n    fn parse_index_spec_reversed_range() {\n        assert!(parse_index_spec(\"5-2\").is_err());\n    }\n\n    #[test]\n    fn parse_index_spec_trailing_comma() {\n        assert_eq!(parse_index_spec(\"1,2,\").unwrap(), vec![1, 2]);\n    }\n\n    #[test]\n    fn run_id_format() {\n        let id = run_id();\n        let parts: Vec<&str> = id.splitn(2, '_').collect();\n        assert_eq!(parts.len(), 2);\n        assert!(parts[0].parse::<u64>().is_ok());\n        assert_eq!(parts[1].len(), 32);\n    
}\n\n    #[test]\n    fn run_id_unique() {\n        let id1 = run_id();\n        let id2 = run_id();\n        assert_ne!(id1, id2);\n    }\n\n    #[test]\n    fn cli_parse_slice_command() {\n        let cli = Cli::parse_from([\"dsperse\", \"slice\", \"--model-dir\", \"/tmp/model\"]);\n        assert!(matches!(cli.command, Commands::Slice(_)));\n    }\n\n    #[test]\n    fn cli_parse_run_command() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"run\",\n            \"--model-dir\",\n            \"/tmp/model\",\n            \"--input-file\",\n            \"/tmp/input.json\",\n        ]);\n        assert!(matches!(cli.command, Commands::Run(_)));\n    }\n\n    #[test]\n    fn cli_log_level_default() {\n        let cli = Cli::parse_from([\"dsperse\", \"slice\", \"--model-dir\", \"/tmp\"]);\n        assert_eq!(cli.log_level, \"warn\");\n    }\n\n    #[test]\n    fn cli_log_level_override() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"--log-level\",\n            \"debug\",\n            \"slice\",\n            \"--model-dir\",\n            \"/tmp\",\n        ]);\n        assert_eq!(cli.log_level, \"debug\");\n    }\n\n    #[test]\n    fn cli_compile_with_layers() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"compile\",\n            \"--model-dir\",\n            \"/tmp\",\n            \"--layers\",\n            \"0,2-4\",\n        ]);\n        if let Commands::Compile(args) = cli.command {\n            assert_eq!(args.layers.as_deref(), Some(\"0,2-4\"));\n        } else {\n            panic!(\"expected Compile\");\n        }\n    }\n\n    #[test]\n    fn cli_run_parallel() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"run\",\n            \"--model-dir\",\n            \"/tmp\",\n            \"--input-file\",\n            \"/tmp/in.json\",\n            \"--parallel\",\n            \"4\",\n        ]);\n        if let Commands::Run(args) = 
cli.command {\n            assert_eq!(args.parallel.get(), 4);\n        } else {\n            panic!(\"expected Run\");\n        }\n    }\n\n    #[test]\n    fn cli_slice_with_tile_size() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"slice\",\n            \"--model-dir\",\n            \"/tmp\",\n            \"--tile-size\",\n            \"1024\",\n        ]);\n        if let Commands::Slice(args) = cli.command {\n            assert_eq!(args.tile_size, Some(1024));\n        } else {\n            panic!(\"expected Slice\");\n        }\n    }\n\n    #[test]\n    fn cli_parse_combine_command() {\n        let cli = Cli::parse_from([\"dsperse\", \"combine\", \"--model-dir\", \"/tmp/model\"]);\n        assert!(matches!(cli.command, Commands::Combine(_)));\n    }\n\n    #[test]\n    fn cli_parse_combine_with_slices_dir() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"combine\",\n            \"--model-dir\",\n            \"/tmp/model\",\n            \"--slices-dir\",\n            \"/tmp/slices\",\n        ]);\n        if let Commands::Combine(args) = cli.command {\n            assert_eq!(\n                args.slices_dir,\n                Some(std::path::PathBuf::from(\"/tmp/slices\"))\n            );\n        } else {\n            panic!(\"expected Combine\");\n        }\n    }\n\n    #[test]\n    fn cli_run_combined_default_true() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"run\",\n            \"--model-dir\",\n            \"/tmp\",\n            \"--input-file\",\n            \"/tmp/in.json\",\n        ]);\n        if let Commands::Run(args) = cli.command {\n            assert!(args.combined);\n        } else {\n            panic!(\"expected Run\");\n        }\n    }\n\n    #[test]\n    fn cli_run_combined_explicit_false() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"run\",\n            \"--model-dir\",\n            \"/tmp\",\n          
  \"--input-file\",\n            \"/tmp/in.json\",\n            \"--combined\",\n            \"false\",\n        ]);\n        if let Commands::Run(args) = cli.command {\n            assert!(!args.combined);\n        } else {\n            panic!(\"expected Run\");\n        }\n    }\n\n    #[test]\n    fn cli_compile_holographic_default_false() {\n        let cli = Cli::parse_from([\"dsperse\", \"compile\", \"--model-dir\", \"/tmp\"]);\n        if let Commands::Compile(args) = cli.command {\n            assert!(!args.holographic);\n        } else {\n            panic!(\"expected Compile\");\n        }\n    }\n\n    #[test]\n    fn cli_compile_holographic_explicit_true() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"compile\",\n            \"--model-dir\",\n            \"/tmp\",\n            \"--holographic\",\n            \"true\",\n        ]);\n        if let Commands::Compile(args) = cli.command {\n            assert!(args.holographic);\n        } else {\n            panic!(\"expected Compile\");\n        }\n    }\n\n    #[test]\n    fn cli_full_run_holographic_explicit_true() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"full-run\",\n            \"--model-dir\",\n            \"/tmp\",\n            \"--holographic\",\n            \"true\",\n        ]);\n        if let Commands::FullRun(args) = cli.command {\n            assert!(args.holographic);\n        } else {\n            panic!(\"expected FullRun\");\n        }\n    }\n\n    #[test]\n    fn cli_setup_holographic_command() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"setup-holographic\",\n            \"--model-dir\",\n            \"/tmp\",\n            \"--parallel\",\n            \"4\",\n        ]);\n        if let Commands::SetupHolographic(args) = cli.command {\n            assert_eq!(args.parallel.get(), 4);\n            assert!(!args.overwrite);\n        } else {\n            panic!(\"expected 
SetupHolographic\");\n        }\n    }\n\n    #[test]\n    fn cli_setup_holographic_overwrite() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"setup-holographic\",\n            \"--model-dir\",\n            \"/tmp\",\n            \"--overwrite\",\n            \"true\",\n        ]);\n        if let Commands::SetupHolographic(args) = cli.command {\n            assert!(args.overwrite);\n        } else {\n            panic!(\"expected SetupHolographic\");\n        }\n    }\n\n    #[test]\n    fn cli_compile_wai_default_true() {\n        let cli = Cli::parse_from([\"dsperse\", \"compile\", \"--model-dir\", \"/tmp\"]);\n        if let Commands::Compile(args) = cli.command {\n            assert!(args.weights_as_inputs);\n        } else {\n            panic!(\"expected Compile\");\n        }\n    }\n\n    #[test]\n    fn cli_compile_wai_explicit_false() {\n        let cli = Cli::parse_from([\n            \"dsperse\",\n            \"compile\",\n            \"--model-dir\",\n            \"/tmp\",\n            \"--weights-as-inputs\",\n            \"false\",\n        ]);\n        if let Commands::Compile(args) = cli.command {\n            assert!(!args.weights_as_inputs);\n        } else {\n            panic!(\"expected Compile\");\n        }\n    }\n\n    #[test]\n    fn resolve_circuit_ops_invalid_proof_system() {\n        let result = resolve_circuit_ops(\"nonexistent\", None);\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn resolve_circuit_ops_unsupported_op() {\n        let result = resolve_circuit_ops(\"expander\", Some(\"FakeOp\"));\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn resolve_circuit_ops_empty_spec_rejected() {\n        let result = resolve_circuit_ops(\"expander\", Some(\"\"));\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn resolve_circuit_ops_whitespace_only_spec_rejected() {\n        let result = resolve_circuit_ops(\"expander\", Some(\" ,  , \"));\n        
assert!(result.is_err());\n    }\n\n    #[test]\n    fn resolve_circuit_ops_valid_specific_ops() {\n        let supported = ProofSystem::Expander.supported_ops();\n        assert!(!supported.is_empty());\n        let first_op = supported[0];\n        let ops = resolve_circuit_ops(\"expander\", Some(first_op)).unwrap();\n        assert_eq!(ops.as_refs(), vec![first_op]);\n    }\n\n    #[test]\n    fn resolve_circuit_ops_none_returns_all() {\n        let ops = resolve_circuit_ops(\"expander\", None).unwrap();\n        let expected: Vec<&str> = ProofSystem::Expander.supported_ops().to_vec();\n        assert_eq!(ops.as_refs(), expected);\n    }\n\n    #[test]\n    fn resolve_slices_dir_custom_path() {\n        let result = resolve_slices_dir(Some(PathBuf::from(\"/custom\")), Path::new(\"/model\"));\n        assert_eq!(result, PathBuf::from(\"/custom\"));\n    }\n\n    #[test]\n    fn resolve_slices_dir_default_fallback() {\n        let model_dir = Path::new(\"/model\");\n        let result = resolve_slices_dir(None, model_dir);\n        assert_eq!(result, model_dir.join(\"slices\"));\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/converter.rs",
    "content": "use std::collections::{HashMap, HashSet};\nuse std::path::Path;\n\nuse jstprove_circuits::api::{\n    self, ArchitectureType as Architecture, CircuitParamsType as CircuitParams, WANDBType as WANDB,\n};\n\nuse crate::error::{DsperseError, Result};\n\npub fn prepare_jstprove_artifacts(\n    onnx_path: &Path,\n    weights_as_inputs: bool,\n) -> Result<(CircuitParams, Architecture, WANDB)> {\n    prepare_jstprove_artifacts_filtered(onnx_path, weights_as_inputs, &HashSet::new(), None)\n}\n\npub fn prepare_jstprove_artifacts_filtered(\n    onnx_path: &Path,\n    weights_as_inputs: bool,\n    exclude_from_wai: &HashSet<String>,\n    traced_shapes: Option<&HashMap<String, Vec<i64>>>,\n) -> Result<(CircuitParams, Architecture, WANDB)> {\n    let meta = match traced_shapes {\n        Some(shapes) => {\n            let converted: HashMap<String, Vec<usize>> = shapes\n                .iter()\n                .map(|(k, v)| {\n                    (\n                        k.clone(),\n                        v.iter()\n                            .map(|&d| if d < 0 { 1 } else { d as usize })\n                            .collect(),\n                    )\n                })\n                .collect();\n            api::generate_metadata_with_shapes(onnx_path, converted)\n        }\n        None => api::generate_metadata(onnx_path),\n    }\n    .map_err(|e| DsperseError::Pipeline(format!(\"ONNX metadata generation: {e:#}\")))?;\n\n    let mut params = meta.circuit_params;\n    if weights_as_inputs {\n        api::populate_wai_inputs(&mut params, &meta.wandb, exclude_from_wai)\n            .map_err(|e| DsperseError::Pipeline(format!(\"WAI input population: {e}\")))?;\n    }\n\n    Ok((params, meta.architecture, meta.wandb))\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn prepare_jstprove_artifacts_nonexistent_model() {\n        let result = prepare_jstprove_artifacts(Path::new(\"/nonexistent.onnx\"), false);\n        
assert!(result.is_err());\n    }\n\n    #[test]\n    fn prepare_jstprove_artifacts_with_weights_as_inputs() {\n        let result = prepare_jstprove_artifacts(Path::new(\"/nonexistent.onnx\"), true);\n        assert!(result.is_err());\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/error.rs",
    "content": "use std::path::PathBuf;\n\npub type Result<T> = std::result::Result<T, DsperseError>;\n\n#[derive(Debug, thiserror::Error)]\npub enum DsperseError {\n    #[error(\"I/O error at {}: {source}\", .path.file_name().and_then(|n| n.to_str()).unwrap_or(\"<unknown>\"))]\n    Io {\n        source: std::io::Error,\n        path: PathBuf,\n    },\n\n    #[error(\"msgpack encode error: {0}\")]\n    MsgpackEncode(#[from] rmp_serde::encode::Error),\n\n    #[error(\"msgpack decode error: {0}\")]\n    MsgpackDecode(#[from] rmp_serde::decode::Error),\n\n    #[error(\"ONNX error: {0}\")]\n    Onnx(String),\n\n    #[error(\"backend error: {0}\")]\n    Backend(String),\n\n    #[error(\"slicer error: {0}\")]\n    Slicer(String),\n\n    #[error(\"archive error: {0}\")]\n    Archive(String),\n\n    #[error(\"metadata error: {0}\")]\n    Metadata(String),\n\n    #[error(\"pipeline error: {0}\")]\n    Pipeline(String),\n\n    #[error(\"{0}\")]\n    Other(String),\n}\n\nimpl DsperseError {\n    pub fn io(source: std::io::Error, path: impl Into<PathBuf>) -> Self {\n        Self::Io {\n            source,\n            path: path.into(),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/lib.rs",
    "content": "pub mod backend;\npub mod cli;\npub mod converter;\npub mod error;\npub mod pipeline;\npub mod schema;\npub mod slicer;\npub mod utils;\npub mod version;\n\n#[cfg(feature = \"python\")]\nmod python;\n"
  },
  {
    "path": "crates/dsperse/src/main.rs",
    "content": "use clap::Parser;\nuse tracing_subscriber::EnvFilter;\n\nuse dsperse::cli;\n\nfn main() {\n    let parsed = cli::Cli::parse();\n\n    tracing_subscriber::fmt()\n        .with_env_filter(\n            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&parsed.log_level)),\n        )\n        .init();\n\n    eprintln!(\"dsperse {}\", cli::VERSION);\n\n    if let Err(e) = cli::dispatch(parsed.command) {\n        tracing::error!(\"{e}\");\n        std::process::exit(1);\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/channel_split.rs",
    "content": "use std::collections::HashMap;\nuse std::path::Path;\n\nuse ndarray::{Array4, ArrayD, s};\n\nuse super::runner::{generate_wai_witness, resolve_circuit_path_optional, run_onnx_inference};\nuse super::tensor_store::TensorStore;\nuse crate::backend::jstprove::JstproveBackend;\nuse crate::error::{DsperseError, Result};\nuse crate::schema::execution::{ExecutionInfo, ExecutionMethod};\nuse crate::schema::tiling::{ChannelGroupInfo, ChannelSplitInfo};\nuse crate::slicer::onnx_proto::TensorProto;\nuse crate::utils::io::read_msgpack;\nuse crate::utils::paths::resolve_relative_path;\n\npub(crate) fn reshape_channel_split_output(\n    arr: ArrayD<f64>,\n    target_shape: Option<&[i64]>,\n) -> Result<ArrayD<f64>> {\n    let Some(raw) = target_shape else {\n        return Ok(arr);\n    };\n    let target: Vec<usize> = raw\n        .iter()\n        .map(|&d| {\n            usize::try_from(d).map_err(|_| {\n                DsperseError::Pipeline(format!(\"negative dimension {d} in output_shape\"))\n            })\n        })\n        .collect::<Result<Vec<_>>>()?;\n    if arr.shape() == target.as_slice() {\n        return Ok(arr);\n    }\n    let actual_shape: Vec<usize> = arr.shape().to_vec();\n    let actual_elems: usize = actual_shape.iter().product();\n    let target_elems: usize = target.iter().product();\n    if actual_elems != target_elems {\n        return Err(DsperseError::Pipeline(format!(\n            \"channel_split output element count mismatch: \\\n             actual {actual_elems} (shape {actual_shape:?}) vs target {target_elems} (shape {target:?})\"\n        )));\n    }\n    arr.into_shape_with_order(ndarray::IxDyn(&target))\n        .map_err(|e| {\n            DsperseError::Pipeline(format!(\n                \"channel_split output reshape from {actual_shape:?} to {target:?}: {e}\",\n            ))\n        })\n}\n\n#[allow(clippy::too_many_arguments)]\npub(crate) fn execute_channel_split(\n    slices_dir: &Path,\n    slice_run_dir: &Path,\n    
slice_id: &str,\n    cs: &ChannelSplitInfo,\n    target_shape: Option<&[i64]>,\n    tensor_cache: &TensorStore,\n    backend: &JstproveBackend,\n    donor_init_map: Option<&HashMap<String, &TensorProto>>,\n) -> Result<crate::schema::execution::StrategyOutput> {\n    let input_arr = tensor_cache.get(&cs.input_name)?.clone();\n\n    let (input_4d, n, h) = if input_arr.ndim() == 4 {\n        let s = input_arr.shape();\n        let n = s[0];\n        if n != 1 {\n            return Err(DsperseError::Pipeline(format!(\n                \"channel split: batch size {n} not supported, expected 1\"\n            )));\n        }\n        let h = s[2];\n        let arr =\n            Array4::from_shape_vec((n, s[1], s[2], s[3]), input_arr.iter().copied().collect())\n                .map_err(|e| DsperseError::Pipeline(format!(\"channel split reshape: {e}\")))?;\n        (arr, n, h)\n    } else {\n        let n = 1usize;\n        let input_flat: Vec<f64> = input_arr.iter().copied().collect();\n        let total_elements = input_flat.len();\n        let nc = n * cs.c_in;\n        if nc > 0 && !total_elements.is_multiple_of(nc) {\n            return Err(DsperseError::Pipeline(format!(\n                \"channel split reshape: total_elements {total_elements} not divisible by n*c_in ({nc})\"\n            )));\n        }\n        let spatial = if cs.c_in > 0 && total_elements > 0 {\n            total_elements / nc\n        } else {\n            cs.h * cs.w\n        };\n        let h = cs.h.max(1);\n        if spatial > 0 && h > 0 && spatial % h != 0 {\n            return Err(DsperseError::Pipeline(format!(\n                \"channel split reshape: spatial {spatial} not divisible by h={h}\"\n            )));\n        }\n        let w = if spatial > 0 && h > 0 {\n            spatial / h\n        } else {\n            cs.w.max(1)\n        };\n        let arr = Array4::from_shape_vec((n, cs.c_in, h, w), input_flat)\n            .map_err(|e| DsperseError::Pipeline(format!(\"channel split 
reshape: {e}\")))?;\n        (arr, n, h)\n    };\n\n    let mut accumulated: Option<Array4<f64>> = None;\n\n    tracing::info!(\n        slice = %slice_id,\n        num_groups = cs.groups.len(),\n        \"channel split execution\"\n    );\n\n    let n_channels = input_4d.shape()[1];\n    for group in &cs.groups {\n        if group.c_end > n_channels || group.c_start > group.c_end {\n            return Err(DsperseError::Pipeline(format!(\n                \"channel group {} bounds [{}, {}) exceed channel dimension {}\",\n                group.group_idx, group.c_start, group.c_end, n_channels\n            )));\n        }\n        let group_input = input_4d\n            .slice(s![.., group.c_start..group.c_end, .., ..])\n            .to_owned();\n        let group_input_dyn = group_input.into_dyn();\n\n        let group_dir = slice_run_dir.join(format!(\"group_{}\", group.group_idx));\n        std::fs::create_dir_all(&group_dir).map_err(|e| DsperseError::io(e, &group_dir))?;\n\n        let group_output = execute_channel_group(\n            slices_dir,\n            &group_dir,\n            group,\n            &group_input_dyn,\n            backend,\n            donor_init_map,\n        )?;\n\n        let group_4d = if group_output.ndim() == 4 {\n            let s = group_output.shape();\n            Array4::from_shape_vec(\n                (s[0], s[1], s[2], s[3]),\n                group_output.iter().copied().collect(),\n            )\n            .map_err(|e| DsperseError::Pipeline(format!(\"group output reshape: {e}\")))?\n        } else {\n            let group_flat: Vec<f64> = group_output.iter().copied().collect();\n            let (out_h, out_w) = if cs.out_h > 0 && cs.out_w > 0 {\n                (cs.out_h, cs.out_w)\n            } else if cs.c_out > 0 {\n                let out_spatial = group_flat.len() / (n * cs.c_out);\n                if h > 0 && out_spatial > 0 && out_spatial.is_multiple_of(h) {\n                    (h, out_spatial / h)\n                } 
else {\n                    return Err(DsperseError::Pipeline(format!(\n                        \"cannot determine spatial layout for channel_split output: {} elements, c_out={}, set out_h/out_w in metadata\",\n                        group_flat.len(),\n                        cs.c_out\n                    )));\n                }\n            } else {\n                return Err(DsperseError::Pipeline(\"channel split c_out is 0\".into()));\n            };\n            if n * cs.c_out * out_h * out_w != group_flat.len() {\n                return Err(DsperseError::Pipeline(format!(\n                    \"group output reshape mismatch: expected {} elements (n={}, c_out={}, h={}, w={}), got {}\",\n                    n * cs.c_out * out_h * out_w,\n                    n,\n                    cs.c_out,\n                    out_h,\n                    out_w,\n                    group_flat.len()\n                )));\n            }\n            Array4::from_shape_vec((n, cs.c_out, out_h, out_w), group_flat)\n                .map_err(|e| DsperseError::Pipeline(format!(\"group output reshape: {e}\")))?\n        };\n\n        accumulated = Some(match accumulated {\n            Some(acc) => {\n                if acc.shape() != group_4d.shape() {\n                    return Err(DsperseError::Pipeline(format!(\n                        \"channel group {} shape {:?} does not match accumulator shape {:?}\",\n                        group.group_idx,\n                        group_4d.shape(),\n                        acc.shape()\n                    )));\n                }\n                acc + &group_4d\n            }\n            None => group_4d,\n        });\n    }\n\n    if let Some(ref bias_path_str) = cs.bias_path {\n        let bias_file = resolve_relative_path(slices_dir, bias_path_str)?;\n        if !bias_file.exists() {\n            return Err(DsperseError::Pipeline(format!(\n                \"configured bias file not found: {} (bias_path={bias_path_str})\",\n            
    bias_file.display()\n            )));\n        }\n        let bias_data = read_msgpack(&bias_file)?;\n        let bias_flat = crate::utils::io::flatten_nested_list(&bias_data);\n        if bias_flat.len() != cs.c_out {\n            return Err(DsperseError::Pipeline(format!(\n                \"bias length {} does not match c_out {}\",\n                bias_flat.len(),\n                cs.c_out\n            )));\n        }\n        if let Some(ref mut acc) = accumulated {\n            for ((_, c, _, _), val) in acc.indexed_iter_mut() {\n                *val += bias_flat[c];\n            }\n        }\n    }\n\n    let output = match accumulated {\n        Some(acc) => reshape_channel_split_output(acc.into_dyn(), target_shape)?,\n        None => {\n            return Err(DsperseError::Pipeline(format!(\n                \"channel_split produced no output for '{}'\",\n                cs.output_name\n            )));\n        }\n    };\n\n    Ok(crate::schema::execution::StrategyOutput {\n        info: ExecutionInfo {\n            method: ExecutionMethod::ChannelSplit,\n            success: true,\n            error: None,\n            witness_file: None,\n            tile_exec_infos: Vec::new(),\n        },\n        outputs: vec![(cs.output_name.clone(), output)],\n    })\n}\n\nfn execute_channel_group(\n    slices_dir: &Path,\n    group_dir: &Path,\n    group: &ChannelGroupInfo,\n    group_input: &ArrayD<f64>,\n    backend: &JstproveBackend,\n    donor_init_map: Option<&HashMap<String, &TensorProto>>,\n) -> Result<ArrayD<f64>> {\n    let onnx_path = resolve_relative_path(slices_dir, &group.path)?;\n\n    let patched_onnx = if let Some(map) = donor_init_map {\n        Some(crate::slicer::onnx_proto::build_patched_onnx(\n            &onnx_path, map,\n        )?)\n    } else {\n        None\n    };\n    let effective_onnx = patched_onnx\n        .as_ref()\n        .map_or(onnx_path.as_path(), |t| t.path());\n\n    if let Some(circuit_path) =\n        
resolve_circuit_path_optional(slices_dir, group.jstprove_circuit_path.as_deref())?\n    {\n        let params = backend.load_params(&circuit_path)?;\n        let is_wai = params.as_ref().is_some_and(|p| p.weights_as_inputs);\n\n        if donor_init_map.is_some() && !is_wai {\n            return Err(DsperseError::Pipeline(format!(\n                \"group_{}: consumer weights require circuits compiled with --weights-as-inputs\",\n                group.group_idx\n            )));\n        }\n\n        let output_tensor = run_onnx_inference(effective_onnx, group_input)?;\n\n        let flat: Vec<f64> = group_input.iter().copied().collect();\n        let witness_bytes = if is_wai {\n            generate_wai_witness(\n                backend,\n                &circuit_path,\n                &onnx_path,\n                donor_init_map,\n                params.as_ref().unwrap(),\n                &flat,\n            )?\n        } else {\n            backend.witness_f64(&circuit_path, &flat, &[])?\n        };\n\n        let witness_path = group_dir.join(crate::utils::paths::WITNESS_FILE);\n        std::fs::write(&witness_path, &witness_bytes)\n            .map_err(|e| DsperseError::io(e, &witness_path))?;\n\n        Ok(output_tensor)\n    } else {\n        run_onnx_inference(effective_onnx, group_input)\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/combined.rs",
    "content": "use std::collections::{HashMap, HashSet};\nuse std::path::{Path, PathBuf};\n\nuse ndarray::{ArrayD, IxDyn};\n\nuse super::incremental::SliceWork;\nuse super::runner::{build_execution_chain, build_run_metadata, load_model_metadata};\nuse super::strategy::ExecutionStrategy;\nuse super::tensor_store::TensorStore;\nuse crate::backend::onnx::NamedOutputs;\nuse crate::error::{DsperseError, Result};\nuse crate::schema::execution::{ExecutionChain, RunMetadata};\nuse crate::schema::metadata::ModelMetadata;\n\npub struct CombinedRun {\n    tensor_cache: TensorStore,\n    model_meta: ModelMetadata,\n    run_meta: RunMetadata,\n    execution_chain: ExecutionChain,\n    slices_dir: PathBuf,\n    pending_slices: HashSet<String>,\n    failed_slices: HashSet<String>,\n}\n\nimpl CombinedRun {\n    pub fn new(slices_dir: &Path, input: ArrayD<f64>) -> Result<Self> {\n        let model_meta = load_model_metadata(slices_dir)?;\n\n        let combined_path =\n            crate::slicer::combiner::ensure_combined_materialized(slices_dir, &model_meta)?;\n\n        crate::slicer::materializer::ensure_all_slices_materialized(slices_dir, &model_meta)?;\n\n        let first_slice = model_meta\n            .slices\n            .first()\n            .ok_or_else(|| DsperseError::Pipeline(\"model has no slices\".into()))?;\n        let declared_inputs = &first_slice.dependencies.filtered_inputs;\n        if declared_inputs.is_empty() {\n            return Err(DsperseError::Pipeline(\n                \"first slice has no input dependency\".into(),\n            ));\n        }\n\n        let named_outputs = run_combined_onnx(&combined_path, &input, declared_inputs)?;\n\n        let mut tensor_cache = TensorStore::new();\n        for (name, (data, shape)) in &named_outputs {\n            let arr = ArrayD::from_shape_vec(IxDyn(shape), data.clone())\n                .map_err(|e| DsperseError::Pipeline(format!(\"output reshape '{name}': {e}\")))?;\n            
tensor_cache.put(name.clone(), arr);\n        }\n        for name in declared_inputs {\n            if !tensor_cache.contains(name) {\n                tensor_cache.put(name.clone(), input.clone());\n            }\n        }\n\n        // Seed the tensor_cache with any initializer-backed tensor\n        // the slice metadata references.  The slicer's constant-\n        // folding passes can turn intermediate tensors (e.g. a\n        // Transpose over a constant) into initializers in the\n        // transformed graph, while leaving downstream slice\n        // metadata pointing at the original tensor name.  ORT\n        // does not emit those names among its named outputs (they\n        // are not declared as graph outputs of combined.onnx and\n        // have no producing node), so without this seed the\n        // subsequent `tensor_cache.get` in `all_circuit_work` fails\n        // with `tensor '<name>' not found in store` and the whole\n        // run aborts before a single DSlice gets dispatched.\n        seed_tensor_cache_from_initializers(&combined_path, &model_meta, &mut tensor_cache)?;\n\n        let chain = build_execution_chain(&model_meta, slices_dir)?;\n        let run_meta = build_run_metadata(&model_meta, slices_dir, &chain)?;\n\n        let mut pending_slices = HashSet::new();\n        for slice in &model_meta.slices {\n            let slice_id = format!(\"slice_{}\", slice.index);\n            let node = chain.nodes.get(&slice_id).ok_or_else(|| {\n                DsperseError::Pipeline(format!(\"execution chain missing node for {slice_id}\"))\n            })?;\n            if node.use_circuit {\n                pending_slices.insert(slice_id);\n            }\n        }\n\n        tracing::info!(\n            total_slices = model_meta.slices.len(),\n            circuit_slices = pending_slices.len(),\n            cached_tensors = tensor_cache.len(),\n            \"combined inference complete, all circuit work queued\"\n        );\n\n        Ok(Self {\n 
           tensor_cache,\n            model_meta,\n            run_meta,\n            execution_chain: chain,\n            slices_dir: slices_dir.to_path_buf(),\n            pending_slices,\n            failed_slices: HashSet::new(),\n        })\n    }\n\n    pub fn all_circuit_work(&self) -> Result<Vec<SliceWork>> {\n        let mut work_items = Vec::with_capacity(self.pending_slices.len());\n\n        for slice in &self.model_meta.slices {\n            let slice_id = format!(\"slice_{}\", slice.index);\n            if !self.pending_slices.contains(&slice_id) {\n                continue;\n            }\n\n            let node = self.execution_chain.nodes.get(&slice_id).ok_or_else(|| {\n                DsperseError::Pipeline(format!(\"execution chain missing node for {slice_id}\"))\n            })?;\n\n            let meta = self.run_meta.slices.get(&slice_id).ok_or_else(|| {\n                DsperseError::Pipeline(format!(\"run metadata missing slice {slice_id}\"))\n            })?;\n\n            let strategy = ExecutionStrategy::from_metadata(meta, node.use_circuit)?;\n            let (input, named_inputs) = match strategy {\n                ExecutionStrategy::ChannelSplit(cs) => {\n                    let t = self.tensor_cache.get(&cs.input_name)?.clone();\n                    (t, Vec::new())\n                }\n                ExecutionStrategy::DimSplit(ds) => {\n                    let t = self.tensor_cache.get(&ds.input_name)?.clone();\n                    (t, Vec::new())\n                }\n                ExecutionStrategy::Tiled(tiling) => {\n                    let t = self.tensor_cache.get(&tiling.input_name)?.clone();\n                    (t, Vec::new())\n                }\n                ExecutionStrategy::Single { .. 
} => {\n                    let filtered = &meta.dependencies.filtered_inputs;\n                    let mut named = Vec::with_capacity(filtered.len());\n                    let mut flat_elems: Vec<f64> = Vec::new();\n                    for name in filtered {\n                        let arr = self.tensor_cache.get(name)?;\n                        named.push((name.clone(), arr.clone()));\n                        flat_elems.extend(arr.iter());\n                    }\n                    let concatenated = ndarray::ArrayD::from_shape_vec(\n                        ndarray::IxDyn(&[flat_elems.len()]),\n                        flat_elems,\n                    )\n                    .map_err(|e| DsperseError::Pipeline(format!(\"flatten inputs: {e}\")))?;\n                    (concatenated, named)\n                }\n            };\n\n            work_items.push(SliceWork {\n                slice_id,\n                input,\n                named_inputs,\n                backend: node.backend,\n                use_circuit: node.use_circuit,\n                tiling: meta.tiling.clone(),\n                channel_split: meta.channel_split.clone(),\n                circuit_path: node.circuit_path.clone(),\n                onnx_path: node.onnx_path.clone(),\n                slice_meta: meta.clone(),\n            });\n        }\n\n        Ok(work_items)\n    }\n\n    pub fn mark_slice_done(&mut self, slice_id: &str) -> bool {\n        self.pending_slices.remove(slice_id)\n    }\n\n    pub fn mark_slice_failed(&mut self, slice_id: &str) -> bool {\n        let was_pending = self.pending_slices.remove(slice_id);\n        if was_pending {\n            self.failed_slices.insert(slice_id.to_string());\n        }\n        was_pending\n    }\n\n    pub fn is_slice_failed(&self, slice_id: &str) -> bool {\n        self.failed_slices.contains(slice_id)\n    }\n\n    pub fn failed_count(&self) -> usize {\n        self.failed_slices.len()\n    }\n\n    pub fn is_complete(&self) -> bool {\n  
      self.pending_slices.is_empty()\n    }\n\n    pub fn model_meta(&self) -> &ModelMetadata {\n        &self.model_meta\n    }\n\n    pub fn final_output(&self) -> Option<&ArrayD<f64>> {\n        let last_slice = self.model_meta.slices.last()?;\n        let slice_id = format!(\"slice_{}\", last_slice.index);\n        let meta = self.run_meta.slices.get(&slice_id)?;\n\n        let strategy = ExecutionStrategy::from_metadata(meta, false).ok()?;\n        match strategy.output_name() {\n            Some(name) => self.tensor_cache.try_get(name),\n            None => {\n                let output_name = meta.dependencies.output.first()?;\n                self.tensor_cache.try_get(output_name)\n            }\n        }\n    }\n\n    pub fn expected_slice_outputs(&self, slice_id: &str) -> Option<Vec<f64>> {\n        let meta = self.run_meta.slices.get(slice_id)?;\n        let output_names = &meta.dependencies.output;\n        self.outputs_for_names(output_names)\n    }\n\n    pub fn outputs_for_names(&self, names: &[String]) -> Option<Vec<f64>> {\n        let mut flat = Vec::new();\n        for name in names {\n            let tensor = self.tensor_cache.try_get(name)?;\n            flat.extend(tensor.iter());\n        }\n        if flat.is_empty() { None } else { Some(flat) }\n    }\n\n    pub fn slice_tile_counts(&self) -> (usize, usize, HashMap<String, usize>) {\n        let total_slices = self.model_meta.slices.len();\n        let mut map = HashMap::with_capacity(total_slices);\n        let mut total_tiles = 0usize;\n        for s in &self.model_meta.slices {\n            let tiles = s.tiling.as_ref().map(|t| t.num_tiles).unwrap_or(1);\n            map.insert(format!(\"slice_{}\", s.index), tiles);\n            total_tiles += tiles;\n        }\n        (total_slices, total_tiles, map)\n    }\n\n    pub fn slices_dir(&self) -> &Path {\n        &self.slices_dir\n    }\n\n    pub fn pending_count(&self) -> usize {\n        self.pending_slices.len()\n    }\n}\n\nfn 
run_combined_onnx(\n    combined_path: &Path,\n    input: &ArrayD<f64>,\n    declared_inputs: &[String],\n) -> Result<NamedOutputs> {\n    if declared_inputs.len() == 1 {\n        let input_flat: Vec<f64> = input.iter().copied().collect();\n        let input_shape = input.shape();\n        crate::backend::onnx::run_inference_named(combined_path, &input_flat, input_shape)\n    } else {\n        Err(DsperseError::Pipeline(format!(\n            \"combined mode requires single input, got {}\",\n            declared_inputs.len()\n        )))\n    }\n}\n\n/// Populate `tensor_cache` with any combined-graph initializer\n/// whose name appears in slice metadata as a `filtered_input` or a\n/// declared `output`.  Without this, a slice that depends on a\n/// constant-folded tensor (one the slicer turned from a node\n/// output into an initializer) would fail at the\n/// `tensor_cache.get(name)` call in `all_circuit_work` even though\n/// the value is right there in the combined ONNX.\nfn seed_tensor_cache_from_initializers(\n    combined_path: &Path,\n    model_meta: &ModelMetadata,\n    tensor_cache: &mut TensorStore,\n) -> Result<()> {\n    let needed: HashSet<&str> = model_meta\n        .slices\n        .iter()\n        .flat_map(|s| {\n            s.dependencies\n                .filtered_inputs\n                .iter()\n                .chain(s.dependencies.output.iter())\n        })\n        .map(String::as_str)\n        .collect();\n    if needed.is_empty() {\n        return Ok(());\n    }\n\n    let model = crate::slicer::onnx_proto::load_model(combined_path)?;\n    let graph = match &model.graph {\n        Some(g) => g,\n        None => return Ok(()),\n    };\n\n    let mut seeded = 0usize;\n    for init in &graph.initializer {\n        if !needed.contains(init.name.as_str()) {\n            continue;\n        }\n        if tensor_cache.contains(&init.name) {\n            continue;\n        }\n        // Negative dims would silently wrap to huge positive\n        // 
values via `as usize`; reject up front so a malformed\n        // initialiser surfaces an error here instead of\n        // allocating a multi-petabyte array below.\n        let shape: Vec<usize> = match init\n            .dims\n            .iter()\n            .map(|&d| usize::try_from(d))\n            .collect::<std::result::Result<Vec<_>, _>>()\n        {\n            Ok(s) => s,\n            Err(e) => {\n                tracing::debug!(\n                    name = %init.name,\n                    dims = ?init.dims,\n                    error = %e,\n                    \"skipping initializer-backed slice tensor: invalid (negative) dimension\"\n                );\n                continue;\n            }\n        };\n        // Use checked_mul so an arithmetic overflow surfaces as a\n        // skip (and the slice executor downstream produces a\n        // clearer error if it actually needed the value), instead\n        // of wrapping silently and mis-comparing against\n        // `data.len()`.\n        let expected: Option<usize> = shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d));\n        let Some(expected) = expected else {\n            tracing::debug!(\n                name = %init.name,\n                dims = ?init.dims,\n                \"skipping initializer-backed slice tensor: shape product overflowed usize\"\n            );\n            continue;\n        };\n        // Decode straight to f64 so DOUBLE / INT64 initialisers\n        // keep their full precision -- the previous f32-then-widen\n        // chain truncated DOUBLE mantissas and silently lost\n        // precision on INT64 magnitudes outside f32's exact range.\n        let data: Vec<f64> = crate::slicer::onnx_proto::tensor_to_f64(init);\n        if data.len() != expected {\n            // Skip rather than fail: an initialiser whose declared\n            // shape doesn't match its element count can still be\n            // useful elsewhere (some quantised tensors store packed\n      
      // bytes), but we cannot reshape it into ArrayD<f64>\n            // here without guessing.  Leave it to the slice ONNX\n            // executor to surface a clearer error if it actually\n            // needs the value.\n            tracing::debug!(\n                name = %init.name,\n                declared_shape = ?shape,\n                declared_elements = expected,\n                actual_elements = data.len(),\n                \"skipping initializer-backed slice tensor: declared shape != element count\"\n            );\n            continue;\n        }\n        let arr = ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|e| {\n            DsperseError::Pipeline(format!(\n                \"seed initializer-backed tensor '{}' from combined.onnx: {e}\",\n                init.name\n            ))\n        })?;\n        tensor_cache.put(init.name.clone(), arr);\n        seeded += 1;\n    }\n    if seeded > 0 {\n        tracing::info!(\n            seeded,\n            \"seeded tensor_cache with constant-folded slice-input initializers\"\n        );\n    }\n    Ok(())\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/compiler.rs",
    "content": "use std::collections::HashMap;\nuse std::path::{Path, PathBuf};\n\nuse rayon::prelude::*;\n\nuse crate::backend::jstprove::JstproveBackend;\nuse crate::converter;\nuse crate::error::{DsperseError, Result};\nuse crate::schema::metadata::ModelMetadata;\nuse crate::slicer::autotiler::estimate_slice_constraints;\nuse crate::slicer::onnx_proto;\nuse crate::utils::paths::{find_metadata_path, slice_dir_path};\n\ntype CircuitCache = std::sync::Mutex<HashMap<String, PathBuf>>;\n\nenum CompileOutcome {\n    Compiled,\n    CompiledChannelSplit {\n        group_circuits: Vec<(usize, String)>,\n    },\n    CompiledDimSplit,\n    Skipped,\n    SkippedOverSize {\n        estimated: u64,\n        threshold: u64,\n    },\n}\n\n/// Summary of a compile_slices invocation.  The pass returns Ok\n/// even when individual slice compilations fail, so callers must\n/// inspect `failed` to decide whether to proceed (e.g. allow\n/// partial-coverage ONNX fallback) or abort.  Keeping the\n/// compiled count explicit lets the CLI / analyze command\n/// report a structured summary instead of inferring success from\n/// log lines.\n#[derive(Debug, Default)]\npub struct CompileReport {\n    pub compiled: usize,\n    pub failed: Vec<(usize, DsperseError)>,\n}\n\nimpl CompileReport {\n    /// Return Ok(self) when every slice compiled cleanly.  Otherwise\n    /// return a generic Pipeline error; callers layer their own\n    /// actionable guidance on top (the CLI mentions its\n    /// --allow-onnx-fallback flag, the Python binding mentions the\n    /// `allow_onnx_fallback` keyword).  
Keeping the library message\n    /// surface-agnostic avoids leaking CLI conventions into the\n    /// Python / Rust API error stream.\n    pub fn ok_if_no_failures(self) -> Result<Self> {\n        if self.failed.is_empty() {\n            Ok(self)\n        } else {\n            Err(DsperseError::Pipeline(format!(\n                \"compile_slices: {} slice(s) failed to compile; the caller must opt in to partial coverage before proceeding\",\n                self.failed.len()\n            )))\n        }\n    }\n}\n\n/// Backfill split metadata fields that only become resolvable after\n/// slicing (channel_split.groups populated from disk,\n/// dim_split.template_path inferred from the materialized template\n/// ONNX), and strip dim_split entries whose template could not be\n/// materialized.  Called from both compile_slices and analyze_slices\n/// so the two classifications agree on what actually counts as a\n/// channel- or dim-split slice.  Persists the normalised metadata\n/// back to disk when any field changes.\nfn normalize_split_metadata(\n    slices_dir: &Path,\n    meta_path: &Path,\n    metadata: &mut ModelMetadata,\n) -> Result<()> {\n    if metadata.original_model_path.is_some() {\n        crate::slicer::materializer::ensure_all_slices_materialized(slices_dir, metadata)?;\n    }\n\n    let mut metadata_dirty = false;\n    for slice in &mut metadata.slices {\n        if let Some(ref mut cs) = slice.channel_split\n            && cs.groups.is_empty()\n        {\n            let populated = populate_channel_split_groups(slices_dir, slice.index, cs)?;\n            if populated {\n                metadata_dirty = true;\n            }\n        }\n        if let Some(ref mut ds) = slice.dim_split\n            && ds.template_path.is_none()\n        {\n            let tmpl_rel = format!(\"slice_{}/payload/dim_template.onnx\", slice.index);\n            if slices_dir.join(&tmpl_rel).exists() {\n                ds.template_path = Some(tmpl_rel);\n                
metadata_dirty = true;\n            }\n        }\n    }\n    // Strip dim_split metadata from slices where template creation\n    // failed (axis-separability rejection, unsupported split kind).\n    // Leaving stale dim_split entries in the metadata causes\n    // downstream runners and the packager to emit bundles that fail\n    // at the strategy validation stage (\"dim_split present but\n    // template_path is missing\").\n    for slice in &mut metadata.slices {\n        if slice\n            .dim_split\n            .as_ref()\n            .is_some_and(|ds| ds.template_path.is_none())\n        {\n            tracing::info!(\n                slice = slice.index,\n                \"stripping dim_split metadata (no template materialized)\"\n            );\n            slice.dim_split = None;\n            metadata_dirty = true;\n        }\n    }\n    if metadata_dirty {\n        metadata.save(meta_path)?;\n        tracing::info!(\"persisted materialized split groups to metadata\");\n    }\n    Ok(())\n}\n\n#[allow(clippy::too_many_arguments)]\npub fn compile_slices(\n    slices_dir: &Path,\n    backend: &JstproveBackend,\n    proof_config: jstprove_circuits::api::ProofConfigType,\n    parallel: usize,\n    weights_as_inputs: bool,\n    layers: Option<&[usize]>,\n    jstprove_ops: &[&str],\n    skip_compile_over_size: Option<u64>,\n    holographic: bool,\n) -> Result<CompileReport> {\n    if holographic && proof_config != jstprove_circuits::api::ProofConfigType::GoldilocksExt4Whir {\n        return Err(DsperseError::Pipeline(format!(\n            \"--holographic requires --proof-config goldilocks_ext4_whir; got {proof_config}\"\n        )));\n    }\n    let meta_path = find_metadata_path(slices_dir).ok_or_else(|| {\n        DsperseError::Metadata(format!(\n            \"no {} found in slices directory\",\n            crate::utils::paths::METADATA_FILE\n        ))\n    })?;\n    let mut metadata = ModelMetadata::load(&meta_path)?;\n    
normalize_split_metadata(slices_dir, &meta_path, &mut metadata)?;\n\n    let slices: Vec<_> = metadata\n        .slices\n        .iter()\n        .filter(|s| layers.is_none_or(|l| l.contains(&s.index)))\n        .cloned()\n        .collect();\n\n    tracing::info!(total = slices.len(), \"compiling slices\");\n\n    let exclude_from_wai: std::collections::HashSet<String> =\n        metadata.folded_constant_names.iter().cloned().collect();\n\n    let traced_shapes = metadata.traced_shapes.clone();\n    let traced_ref = traced_shapes.as_ref();\n\n    let pool = rayon::ThreadPoolBuilder::new()\n        .num_threads(parallel)\n        .build()\n        .map_err(|e| DsperseError::Pipeline(format!(\"thread pool: {e}\")))?;\n\n    let compiled_count = std::sync::atomic::AtomicUsize::new(0);\n    let meta_mutex = std::sync::Mutex::new((&mut metadata, false));\n    let errors: std::sync::Mutex<Vec<(usize, DsperseError)>> = std::sync::Mutex::new(Vec::new());\n    let circuit_cache: CircuitCache = std::sync::Mutex::new(HashMap::new());\n\n    pool.install(|| {\n        slices.par_iter().for_each(|slice| {\n            let r = compile_single_slice(\n                slices_dir,\n                slice,\n                backend,\n                proof_config,\n                weights_as_inputs,\n                jstprove_ops,\n                &exclude_from_wai,\n                skip_compile_over_size,\n                &circuit_cache,\n                traced_ref,\n                holographic,\n            );\n            match r {\n                Ok(CompileOutcome::Compiled) => {\n                    let count =\n                        compiled_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;\n                    tracing::info!(slice = slice.index, count, \"compiled\");\n                }\n                Ok(CompileOutcome::CompiledChannelSplit { group_circuits }) => {\n                    let count =\n                        compiled_count.fetch_add(1, 
std::sync::atomic::Ordering::Relaxed) + 1;\n                    tracing::info!(\n                        slice = slice.index,\n                        groups = group_circuits.len(),\n                        count,\n                        \"compiled channel split groups\"\n                    );\n                    let mut guard = meta_mutex.lock().unwrap();\n                    let (ref mut meta, ref mut dirty) = *guard;\n                    if let Some(s) = meta.slices.iter_mut().find(|s| s.index == slice.index)\n                        && let Some(ref mut cs) = s.channel_split\n                    {\n                        for (group_idx, circuit_path) in &group_circuits {\n                            if let Some(group) =\n                                cs.groups.iter_mut().find(|g| g.group_idx == *group_idx)\n                            {\n                                group.jstprove_circuit_path = Some(circuit_path.clone());\n                            }\n                        }\n                        *dirty = true;\n                    }\n                }\n                Ok(CompileOutcome::CompiledDimSplit) => {\n                    let count =\n                        compiled_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;\n                    tracing::info!(slice = slice.index, count, \"compiled dim-split template\");\n                    let mut guard = meta_mutex.lock().unwrap();\n                    let (ref mut meta, ref mut dirty) = *guard;\n                    if let Some(s) = meta.slices.iter_mut().find(|s| s.index == slice.index)\n                        && let Some(ref mut ds) = s.dim_split\n                    {\n                        ds.jstprove_circuit_path = Some(format!(\n                            \"slice_{}/jstprove/dim_split/circuit.bundle\",\n                            slice.index\n                        ));\n                        *dirty = true;\n                    }\n                }\n                
Ok(CompileOutcome::Skipped) => {\n                    tracing::info!(slice = slice.index, \"skipped (unsupported ops)\")\n                }\n                Ok(CompileOutcome::SkippedOverSize {\n                    estimated,\n                    threshold,\n                }) => {\n                    tracing::info!(\n                        slice = slice.index,\n                        estimated,\n                        threshold,\n                        \"skipped (estimated constraints exceed threshold)\"\n                    )\n                }\n                Err(e) => {\n                    // Per-slice compile failure is recoverable: the\n                    // summary log at the end of compile_slices\n                    // already surfaces the aggregate via warn!, and\n                    // the caller decides whether to continue with\n                    // partial coverage.  Emitting error! here would\n                    // spam CI for an outcome that ok_if_no_failures\n                    // handles structurally.\n                    tracing::warn!(slice = slice.index, error = %e, \"compilation failed\");\n                    errors.lock().unwrap().push((slice.index, e));\n                }\n            }\n        });\n    });\n\n    let errors = errors.into_inner().unwrap();\n    let (metadata, cs_dirty) = meta_mutex.into_inner().unwrap();\n    if cs_dirty {\n        // Swallowing the save failure would let downstream\n        // analyze / run / package observe an in-memory set of\n        // materialised channel / dim-split circuit paths that the\n        // on-disk metadata doesn't know about -- the very problem\n        // normalize_split_metadata exists to prevent.  
Propagate.\n        metadata.save(&meta_path)?;\n        tracing::info!(\"persisted split circuit paths to metadata\");\n    }\n    let compiled_count = compiled_count.load(std::sync::atomic::Ordering::Relaxed);\n\n    if errors.is_empty() {\n        tracing::info!(count = compiled_count, \"all slices compiled\");\n    } else {\n        tracing::warn!(\n            compiled = compiled_count,\n            failed = errors.len(),\n            \"compilation completed with errors; failed slices fall back to ONNX execution if the caller allows partial coverage\"\n        );\n        for (idx, e) in &errors {\n            tracing::warn!(slice = idx, error = %e, \"slice compilation failed\");\n        }\n    }\n    Ok(CompileReport {\n        compiled: compiled_count,\n        failed: errors,\n    })\n}\n\nstruct SliceAnalysis {\n    compatible: bool,\n    data_movement_only: bool,\n}\n\nconst DATA_MOVEMENT_OPS: &[&str] = &[\n    \"Reshape\",\n    \"Transpose\",\n    \"Flatten\",\n    \"Squeeze\",\n    \"Unsqueeze\",\n    \"Identity\",\n    \"Concat\",\n    \"Split\",\n    \"Gather\",\n    \"Slice\",\n    \"Expand\",\n    \"Tile\",\n    \"Cast\",\n];\n\nfn analyze_slice_onnx(onnx_path: &Path, jstprove_ops: &[&str]) -> Result<SliceAnalysis> {\n    let model = onnx_proto::load_model(onnx_path)?;\n    let graph = model\n        .graph\n        .as_ref()\n        .ok_or_else(|| DsperseError::Slicer(format!(\"no graph in {}\", onnx_path.display())))?;\n    let compatible = graph\n        .node\n        .iter()\n        .all(|n| jstprove_ops.contains(&n.op_type.as_str()));\n    let data_movement_only = !graph.node.is_empty()\n        && graph\n            .node\n            .iter()\n            .all(|n| DATA_MOVEMENT_OPS.contains(&n.op_type.as_str()));\n    Ok(SliceAnalysis {\n        compatible,\n        data_movement_only,\n    })\n}\n\npub(super) fn compute_circuit_signature(tmpl_path: &Path, curve: Option<&str>) -> Result<String> {\n    use sha2::{Digest, Sha256};\n\n    fn 
hash_bytes(hasher: &mut Sha256, b: &[u8]) {\n        hasher.update((b.len() as u64).to_le_bytes());\n        hasher.update(b);\n    }\n\n    let model = onnx_proto::load_model(tmpl_path)?;\n    let graph = model\n        .graph\n        .as_ref()\n        .ok_or_else(|| DsperseError::Slicer(\"no graph for signature\".into()))?;\n    let mut hasher = Sha256::new();\n    if let Some(c) = curve {\n        hash_bytes(&mut hasher, c.as_bytes());\n    }\n    hasher.update((graph.node.len() as u64).to_le_bytes());\n    for node in &graph.node {\n        hash_bytes(&mut hasher, node.op_type.as_bytes());\n        hasher.update((node.input.len() as u64).to_le_bytes());\n        for inp in &node.input {\n            hash_bytes(&mut hasher, inp.as_bytes());\n        }\n        hasher.update((node.output.len() as u64).to_le_bytes());\n        for out in &node.output {\n            hash_bytes(&mut hasher, out.as_bytes());\n        }\n        hasher.update((node.attribute.len() as u64).to_le_bytes());\n        for attr in &node.attribute {\n            hash_bytes(&mut hasher, attr.name.as_bytes());\n            hasher.update(attr.r#type.to_le_bytes());\n            hasher.update(attr.i.to_le_bytes());\n            hasher.update(attr.f.to_le_bytes());\n            hash_bytes(&mut hasher, &attr.s);\n            hasher.update((attr.ints.len() as u64).to_le_bytes());\n            for v in &attr.ints {\n                hasher.update(v.to_le_bytes());\n            }\n            hasher.update((attr.floats.len() as u64).to_le_bytes());\n            for v in &attr.floats {\n                hasher.update(v.to_le_bytes());\n            }\n            hasher.update((attr.strings.len() as u64).to_le_bytes());\n            for v in &attr.strings {\n                hash_bytes(&mut hasher, v);\n            }\n        }\n    }\n    let init_names: std::collections::HashSet<&str> =\n        graph.initializer.iter().map(|i| i.name.as_str()).collect();\n    for vi in &graph.input {\n        if 
init_names.contains(vi.name.as_str()) {\n            continue;\n        }\n        if let Some(shape) = onnx_proto::shape_from_value_info(vi) {\n            hasher.update((shape.len() as u64).to_le_bytes());\n            for d in &shape {\n                hasher.update(d.to_le_bytes());\n            }\n        }\n        if let Some(dt) = onnx_proto::elem_type_from_value_info(vi) {\n            hasher.update(dt.to_le_bytes());\n        }\n    }\n    for vi in &graph.output {\n        if let Some(shape) = onnx_proto::shape_from_value_info(vi) {\n            hasher.update((shape.len() as u64).to_le_bytes());\n            for d in &shape {\n                hasher.update(d.to_le_bytes());\n            }\n        }\n        if let Some(dt) = onnx_proto::elem_type_from_value_info(vi) {\n            hasher.update(dt.to_le_bytes());\n        }\n    }\n    hasher.update((graph.initializer.len() as u64).to_le_bytes());\n    for init in &graph.initializer {\n        hasher.update((init.dims.len() as u64).to_le_bytes());\n        for d in &init.dims {\n            hasher.update(d.to_le_bytes());\n        }\n        hasher.update(init.data_type.to_le_bytes());\n    }\n    let hash = hasher.finalize();\n    Ok(format!(\"{:x}\", hash))\n}\n\n/// Bundle-aware signature used at packaging time. Wraps the ONNX+curve\n/// hash from `compute_circuit_signature` with discriminators pulled\n/// from the compiled bundle so that two packages built from the same\n/// ONNX but under different proof configs, input-binding modes, or\n/// holographic/non-holographic flows land at distinct shas in the\n/// content-addressed registry. 
The compile-time cache lookups in\n/// `compile_single_slice` continue to use `compute_circuit_signature`\n/// directly because they key on pre-compile state where no bundle\n/// exists yet.\npub(super) fn compute_bundle_signature(\n    tmpl_path: &Path,\n    curve: Option<&str>,\n    bundle_dir: &Path,\n) -> Result<String> {\n    use sha2::{Digest, Sha256};\n\n    let base = compute_circuit_signature(tmpl_path, curve)?;\n    let mut hasher = Sha256::new();\n    hasher.update(base.as_bytes());\n    // Stability contract: any change to the on-wire / in-hash layout\n    // of the bytes mixed in below will silently re-shuffle every\n    // content-addressed component sha. The three inputs that must\n    // stay byte-stable are\n    //   * `jstprove_circuits::proof_config::ProofConfig::config_id()`\n    //     (CONFIG_ID integers documented in proof_config.rs),\n    //   * `StampedProofConfig::version` (u32, per\n    //     ProofConfig::current_version),\n    //   * `CircuitParams::weights_as_inputs` serialization (bool).\n    // If any of those change their encoding, bump the version tag in\n    // the marker below (for example `bundle-disambiguator-v2`) so\n    // downstream registries receive a deliberate re-shuffle rather\n    // than a silent one.\n    hasher.update(b\"\\x00bundle-disambiguator-v1\\x00\");\n\n    match jstprove_io::bundle::read_bundle_metadata::<jstprove_circuits::api::CircuitParamsType>(\n        bundle_dir,\n    ) {\n        Ok((Some(params), _)) => {\n            hasher.update([1u8]);\n            match params.proof_config {\n                Some(stamped) => {\n                    hasher.update([1u8]);\n                    hasher.update((stamped.config.config_id() as u64).to_le_bytes());\n                    hasher.update(stamped.version.to_le_bytes());\n                }\n                None => hasher.update([0u8]),\n            }\n            hasher.update([u8::from(params.weights_as_inputs)]);\n        }\n        Ok((None, _)) => {\n        
    hasher.update([0u8]);\n        }\n        Err(e) => {\n            // A malformed or unreadable manifest is meaningfully\n            // different from a bundle that legitimately carries no\n            // metadata. Distinguish the two with separate\n            // discriminator bytes so a corrupt bundle cannot collide\n            // with a clean legacy bundle, and surface the failure in\n            // the tracing log so operators investigating a shifted\n            // sha have the underlying read error to reference.\n            tracing::warn!(\n                bundle = %bundle_dir.display(),\n                error = %e,\n                \"bundle manifest read failed while computing bundle signature; using error discriminator\"\n            );\n            hasher.update([2u8]);\n        }\n    }\n\n    hasher.update([u8::from(jstprove_io::bundle::bundle_has_vk(bundle_dir))]);\n    Ok(format!(\"{:x}\", hasher.finalize()))\n}\n\nfn summarize_onnx_ops(onnx_path: &Path) -> String {\n    let model = match onnx_proto::load_model(onnx_path) {\n        Ok(m) => m,\n        Err(_) => return String::from(\"?\"),\n    };\n    let graph = match model.graph.as_ref() {\n        Some(g) => g,\n        None => return String::from(\"?\"),\n    };\n    let mut counts: std::collections::BTreeMap<&str, usize> = std::collections::BTreeMap::new();\n    for node in &graph.node {\n        *counts.entry(node.op_type.as_str()).or_default() += 1;\n    }\n    counts\n        .iter()\n        .map(|(op, n)| {\n            if *n > 1 {\n                format!(\"{op}x{n}\")\n            } else {\n                op.to_string()\n            }\n        })\n        .collect::<Vec<_>>()\n        .join(\",\")\n}\n\n#[derive(Debug, serde::Serialize)]\npub struct SliceAnalysisReport {\n    pub index: usize,\n    pub backend: String,\n    pub reason: String,\n    pub estimated_constraints: Option<u64>,\n    pub ops: String,\n    pub tiled: bool,\n    pub channel_split: bool,\n    pub dim_split: 
bool,\n    pub circuit_signature: Option<String>,\n}\n\n/// Derive the three metrics SliceAnalysisReport carries from an\n/// ONNX file: op-summary string, constraint estimate, and curve-\n/// stamped circuit signature.  Used from every analyze_slices\n/// branch that can point at a concrete representative ONNX (the\n/// slice's own .onnx for standard slices, the first channel\n/// group's .onnx for channel-split, the dim-split template\n/// ONNX for dim-split).  Failure on any single metric is\n/// non-fatal: we emit empty / None for the affected field and\n/// continue so analyze never aborts on a partially-materialised\n/// slice.\nfn derive_slice_report_metrics(\n    onnx_path: &Path,\n    proof_config: Option<&str>,\n) -> (String, Option<u64>, Option<String>) {\n    if !onnx_path.exists() {\n        return (String::new(), None, None);\n    }\n    let ops = summarize_onnx_ops(onnx_path);\n    let estimated = estimate_onnx_constraints(onnx_path).ok();\n    let signature = compute_circuit_signature(onnx_path, proof_config).ok();\n    (ops, estimated, signature)\n}\n\npub fn analyze_slices(\n    slices_dir: &Path,\n    jstprove_ops: &[&str],\n    skip_compile_over_size: Option<u64>,\n    proof_config: Option<&str>,\n) -> Result<Vec<SliceAnalysisReport>> {\n    let meta_path = find_metadata_path(slices_dir).ok_or_else(|| {\n        DsperseError::Metadata(format!(\n            \"no {} found in slices directory\",\n            crate::utils::paths::METADATA_FILE\n        ))\n    })?;\n    let mut metadata = ModelMetadata::load(&meta_path)?;\n    // Apply the same split-metadata normalisation compile_slices\n    // performs so the backend / reason classifications below see\n    // populated channel_split.groups, inferred dim_split template\n    // paths, and stripped dim_split entries whose template never\n    // materialised.  
Without this step analyze_slices misreports\n    // slices whose split state is implicit in on-disk artefacts.\n    normalize_split_metadata(slices_dir, &meta_path, &mut metadata)?;\n    let mut reports = Vec::with_capacity(metadata.slices.len());\n\n    for slice in &metadata.slices {\n        let slice_dir = slice_dir_path(slices_dir, slice.index);\n        if !slice_dir.exists() {\n            reports.push(SliceAnalysisReport {\n                index: slice.index,\n                backend: \"missing\".into(),\n                reason: \"slice directory not found\".into(),\n                estimated_constraints: None,\n                ops: String::new(),\n                tiled: slice.tiling.is_some(),\n                channel_split: slice.channel_split.is_some(),\n                dim_split: slice.dim_split.is_some(),\n                circuit_signature: None,\n            });\n            continue;\n        }\n\n        if let Some(ref cs) = slice.channel_split\n            && !cs.groups.is_empty()\n        {\n            // Use the first channel-group ONNX as representative\n            // for the reported metrics: every group in the split\n            // shares the same per-chunk topology, so op summary,\n            // constraint estimate, and circuit signature are\n            // group-invariant and the first group is authoritative\n            // for the backend's view of compilation cost.\n            let group_path = slices_dir.join(&cs.groups[0].path);\n            let (ops, estimated, circuit_signature) =\n                derive_slice_report_metrics(&group_path, proof_config);\n            reports.push(SliceAnalysisReport {\n                index: slice.index,\n                backend: \"jstprove\".into(),\n                reason: \"channel-split\".into(),\n                estimated_constraints: estimated,\n                ops,\n                tiled: slice.tiling.is_some(),\n                channel_split: true,\n                dim_split: false,\n         
       circuit_signature,\n            });\n            continue;\n        }\n\n        if let Some(ref ds) = slice.dim_split\n            && let Some(ref tmpl_rel) = ds.template_path\n        {\n            // The dim-split template is the ONNX the backend\n            // actually compiles (one circuit shared across every\n            // group), so it is the correct source for the\n            // reported ops / constraint estimate / circuit\n            // signature.\n            let tmpl_path = slices_dir.join(tmpl_rel);\n            let (ops, estimated, circuit_signature) =\n                derive_slice_report_metrics(&tmpl_path, proof_config);\n            reports.push(SliceAnalysisReport {\n                index: slice.index,\n                backend: \"jstprove\".into(),\n                reason: \"dim-split\".into(),\n                estimated_constraints: estimated,\n                ops,\n                tiled: slice.tiling.is_some(),\n                channel_split: false,\n                dim_split: true,\n                circuit_signature,\n            });\n            continue;\n        }\n\n        let onnx_path = match resolve_compile_onnx(slices_dir, slice) {\n            Ok(p) => p,\n            Err(_) => {\n                // resolve_compile_onnx failing means the slice has\n                // no ONNX artefact on disk at all.  
That is a\n                // genuine \"missing\" state (the analyse footer\n                // already has a dedicated missing count), not an\n                // \"onnx-backend-compatible\" slice.\n                reports.push(SliceAnalysisReport {\n                    index: slice.index,\n                    backend: \"missing\".into(),\n                    reason: \"onnx not found\".into(),\n                    estimated_constraints: None,\n                    ops: String::new(),\n                    tiled: slice.tiling.is_some(),\n                    channel_split: false,\n                    dim_split: false,\n                    circuit_signature: None,\n                });\n                continue;\n            }\n        };\n\n        if !onnx_path.exists() {\n            // Same reasoning as the resolve_compile_onnx Err branch\n            // above: path was resolvable by metadata but the file\n            // is absent, so the slice is missing rather than ONNX-\n            // compatible.\n            reports.push(SliceAnalysisReport {\n                index: slice.index,\n                backend: \"missing\".into(),\n                reason: \"onnx not found\".into(),\n                estimated_constraints: None,\n                ops: String::new(),\n                tiled: slice.tiling.is_some(),\n                channel_split: false,\n                dim_split: false,\n                circuit_signature: None,\n            });\n            continue;\n        }\n\n        let ops = summarize_onnx_ops(&onnx_path);\n        let analysis = analyze_slice_onnx(&onnx_path, jstprove_ops);\n        let estimated = estimate_onnx_constraints(&onnx_path).ok();\n        let sig = compute_circuit_signature(&onnx_path, proof_config).ok();\n\n        let (backend, reason) = match analysis {\n            Ok(a) if !a.compatible => (\"onnx\", \"unsupported ops\"),\n            Ok(a) if a.data_movement_only => (\"onnx\", \"data movement only\"),\n            Ok(_) => {\n     
           if let (Some(est), Some(thresh)) = (estimated, skip_compile_over_size) {\n                    if est > thresh {\n                        (\"onnx\", \"exceeds size threshold\")\n                    } else {\n                        (\"jstprove\", \"compilable\")\n                    }\n                } else {\n                    (\"jstprove\", \"compilable\")\n                }\n            }\n            Err(_) => (\"onnx\", \"analysis failed\"),\n        };\n\n        reports.push(SliceAnalysisReport {\n            index: slice.index,\n            backend: backend.into(),\n            reason: reason.into(),\n            estimated_constraints: estimated,\n            ops,\n            tiled: slice.tiling.is_some(),\n            channel_split: false,\n            dim_split: slice.dim_split.is_some(),\n            circuit_signature: sig,\n        });\n    }\n\n    Ok(reports)\n}\n\nfn estimate_onnx_constraints(onnx_path: &Path) -> Result<u64> {\n    let model = onnx_proto::load_model(onnx_path)?;\n    let graph = model\n        .graph\n        .as_ref()\n        .ok_or_else(|| DsperseError::Slicer(format!(\"no graph in {}\", onnx_path.display())))?;\n    let shapes = extract_graph_shapes(graph);\n    Ok(estimate_slice_constraints(&graph.node, &shapes))\n}\n\nfn extract_graph_shapes(\n    graph: &onnx_proto::GraphProto,\n) -> std::collections::HashMap<String, Vec<i64>> {\n    let mut shapes = std::collections::HashMap::new();\n\n    let extract_vi_shape = |vi: &onnx_proto::ValueInfoProto| -> Option<(String, Vec<i64>)> {\n        let tp = vi.r#type.as_ref()?;\n        if let Some(onnx_proto::onnx::type_proto::Value::TensorType(ref tt)) = tp.value {\n            let dims: Vec<i64> = tt\n                .shape\n                .as_ref()?\n                .dim\n                .iter()\n                .filter_map(|d| {\n                    if let Some(onnx_proto::onnx::tensor_shape_proto::dimension::Value::DimValue(\n                        v,\n               
     )) = d.value\n                    {\n                        Some(v)\n                    } else {\n                        None\n                    }\n                })\n                .collect();\n            if !dims.is_empty() {\n                return Some((vi.name.clone(), dims));\n            }\n        }\n        None\n    };\n\n    for vi in graph\n        .input\n        .iter()\n        .chain(graph.output.iter())\n        .chain(graph.value_info.iter())\n    {\n        if let Some((name, dims)) = extract_vi_shape(vi) {\n            shapes.insert(name, dims);\n        }\n    }\n\n    for init in &graph.initializer {\n        if !init.name.is_empty() && !init.dims.is_empty() {\n            shapes.insert(init.name.clone(), init.dims.clone());\n        }\n    }\n\n    shapes\n}\n\nfn normalize_slice_for_backend(onnx_path: &Path) -> Result<Option<std::path::PathBuf>> {\n    let mut model = onnx_proto::load_model(onnx_path)?;\n    let changes = onnx_proto::normalize_for_circuit_backend(&mut model);\n    if changes == 0 {\n        return Ok(None);\n    }\n    let normalized = onnx_path.with_extension(\"backend.onnx\");\n    onnx_proto::save_model(&model, &normalized)?;\n    Ok(Some(normalized))\n}\n\n#[allow(clippy::too_many_arguments)]\nfn compile_single_slice(\n    slices_dir: &Path,\n    slice: &crate::schema::metadata::SliceMetadata,\n    backend: &JstproveBackend,\n    proof_config: jstprove_circuits::api::ProofConfigType,\n    weights_as_inputs: bool,\n    jstprove_ops: &[&str],\n    exclude_from_wai: &std::collections::HashSet<String>,\n    skip_compile_over_size: Option<u64>,\n    circuit_cache: &CircuitCache,\n    traced_shapes: Option<&std::collections::HashMap<String, Vec<i64>>>,\n    holographic: bool,\n) -> Result<CompileOutcome> {\n    let slice_dir = slice_dir_path(slices_dir, slice.index);\n    if !slice_dir.exists() {\n        return Err(DsperseError::Pipeline(format!(\n            \"slice directory not found: {}\",\n            
slice_dir.display()\n        )));\n    }\n\n    if let Some(ref cs) = slice.channel_split\n        && !cs.groups.is_empty()\n    {\n        return compile_channel_split_slice(\n            slices_dir,\n            slice,\n            cs,\n            backend,\n            proof_config,\n            jstprove_ops,\n            exclude_from_wai,\n            skip_compile_over_size,\n            circuit_cache,\n            traced_shapes,\n            holographic,\n        );\n    }\n\n    if let Some(ref ds) = slice.dim_split\n        && let Some(ref tmpl_rel) = ds.template_path\n    {\n        let tmpl_path = slices_dir.join(tmpl_rel);\n        if tmpl_path.exists() {\n            return compile_dim_split_template(\n                slices_dir,\n                slice,\n                &tmpl_path,\n                backend,\n                proof_config,\n                jstprove_ops,\n                exclude_from_wai,\n                skip_compile_over_size,\n                circuit_cache,\n                traced_shapes,\n                holographic,\n            );\n        }\n    }\n\n    let onnx_path = resolve_compile_onnx(slices_dir, slice)?;\n    if !onnx_path.exists() {\n        return Err(DsperseError::Pipeline(format!(\n            \"ONNX model not found for slice {}: {}\",\n            slice.index,\n            onnx_path.display()\n        )));\n    }\n\n    let analysis = analyze_slice_onnx(&onnx_path, jstprove_ops)?;\n    if !analysis.compatible {\n        return Ok(CompileOutcome::Skipped);\n    }\n    if analysis.data_movement_only {\n        tracing::info!(slice = slice.index, \"skipped (data movement only)\");\n        return Ok(CompileOutcome::Skipped);\n    }\n\n    // The threshold gate needs a concrete estimate; the debug\n    // block below can reuse it so we only re-parse the slice ONNX\n    // for constraint counting once per slice.\n    let mut estimated: Option<u64> = None;\n    if let Some(threshold) = skip_compile_over_size {\n        let est 
= estimate_onnx_constraints(&onnx_path)?;\n        estimated = Some(est);\n        if est > threshold {\n            return Ok(CompileOutcome::SkippedOverSize {\n                estimated: est,\n                threshold,\n            });\n        }\n    }\n\n    let jst_dir = slice_dir.join(\"jstprove\");\n    std::fs::create_dir_all(&jst_dir).map_err(|e| DsperseError::io(e, &jst_dir))?;\n\n    let circuit_path = jst_dir.join(\"circuit.bundle\");\n\n    if circuit_path.is_dir() {\n        match backend.load_params(&circuit_path) {\n            Ok(_) => {\n                tracing::info!(slice = slice.index, \"already compiled, skipping\");\n                if holographic && !jstprove_io::bundle::bundle_has_vk(&circuit_path) {\n                    run_holographic_setup(backend, &circuit_path, slice.index, \"slice\")?;\n                }\n                return Ok(CompileOutcome::Compiled);\n            }\n            Err(e) => {\n                tracing::warn!(slice = slice.index, error = %e, \"cached circuit invalid, recompiling\");\n                std::fs::remove_dir_all(&circuit_path)\n                    .map_err(|e| DsperseError::io(e, &circuit_path))?;\n            }\n        }\n    }\n\n    let effective_wai = weights_as_inputs;\n\n    // The diagnostic bundle re-parses the slice ONNX (once for the\n    // op summary, once for the constraint estimate if we didn't\n    // already gate through it above).  
Skip that work when debug\n    // tracing is disabled -- in a release build across hundreds of\n    // slices it adds up.\n    if tracing::enabled!(tracing::Level::DEBUG) {\n        if estimated.is_none() {\n            estimated = estimate_onnx_constraints(&onnx_path).ok();\n        }\n        let op_summary = summarize_onnx_ops(&onnx_path);\n\n        tracing::debug!(\n            slice = slice.index,\n            onnx = %onnx_path.display(),\n            estimated_constraints = ?estimated,\n            weights_as_inputs = effective_wai,\n            ops = %op_summary,\n            tiled = slice.tiling.is_some(),\n            channel_split = slice.channel_split.is_some(),\n            dim_split = slice.dim_split.is_some(),\n            \"compiling slice\"\n        );\n    }\n\n    let compile_onnx = normalize_slice_for_backend(&onnx_path)?;\n\n    let (params, architecture, wandb) = converter::prepare_jstprove_artifacts_filtered(\n        compile_onnx.as_ref().unwrap_or(&onnx_path),\n        effective_wai,\n        exclude_from_wai,\n        traced_shapes,\n    )?;\n\n    std::panic::catch_unwind(|| {\n        backend.compile(&circuit_path, proof_config, params, architecture, wandb)\n    })\n    .map_err(|p| {\n        let msg = p\n            .downcast_ref::<&str>()\n            .copied()\n            .or_else(|| p.downcast_ref::<String>().map(String::as_str))\n            .unwrap_or(\"unknown panic\");\n        DsperseError::Backend(format!(\"jstprove panicked: {msg}\"))\n    })??;\n\n    if holographic {\n        run_holographic_setup(backend, &circuit_path, slice.index, \"slice\")?;\n    }\n\n    Ok(CompileOutcome::Compiled)\n}\n\n/// Result of a [`setup_holographic_for_slices`] invocation. 
Mirrors\n/// the structure of [`CompileReport`] so callers can surface\n/// per-slice failures with the same handling.\n#[derive(Debug, Default)]\npub struct HolographicSetupReport {\n    pub processed: usize,\n    pub skipped_already_present: usize,\n    pub failed: Vec<(usize, DsperseError)>,\n}\n\nimpl HolographicSetupReport {\n    pub fn ok_if_no_failures(self) -> Result<Self> {\n        if self.failed.is_empty() {\n            Ok(self)\n        } else {\n            Err(DsperseError::Pipeline(format!(\n                \"setup_holographic_for_slices: {} slice bundle(s) failed\",\n                self.failed.len()\n            )))\n        }\n    }\n}\n\n/// Run holographic GKR setup over every compiled bundle under\n/// `slices_dir`. Walks the slice metadata and, for each slice,\n/// processes the conventional bundle paths produced by\n/// [`compile_slices`]: standard (`jstprove/circuit.bundle`),\n/// channel-split (`jstprove/shared/circuit.bundle`), and dim-split\n/// template (`jstprove/dim_split/circuit.bundle`).\n///\n/// Bundles that already carry a `vk.bin` are skipped unless\n/// `overwrite` is set, so this function is idempotent and cheap to\n/// re-run after a partial failure.\npub fn setup_holographic_for_slices(\n    slices_dir: &Path,\n    backend: &JstproveBackend,\n    parallel: usize,\n    overwrite: bool,\n) -> Result<HolographicSetupReport> {\n    let meta_path = find_metadata_path(slices_dir).ok_or_else(|| {\n        DsperseError::Metadata(format!(\n            \"no {} found in slices directory\",\n            crate::utils::paths::METADATA_FILE\n        ))\n    })?;\n    let metadata = ModelMetadata::load(&meta_path)?;\n\n    let mut targets: Vec<(usize, &'static str, PathBuf)> = Vec::new();\n    for slice in &metadata.slices {\n        let slice_dir = slice_dir_path(slices_dir, slice.index);\n        let candidates: [(&'static str, PathBuf); 3] = [\n            (\"slice\", slice_dir.join(\"jstprove\").join(\"circuit.bundle\")),\n            
(\n                \"channel-split-shared\",\n                slice_dir\n                    .join(\"jstprove\")\n                    .join(\"shared\")\n                    .join(\"circuit.bundle\"),\n            ),\n            (\n                \"dim-split-template\",\n                slice_dir\n                    .join(\"jstprove\")\n                    .join(\"dim_split\")\n                    .join(\"circuit.bundle\"),\n            ),\n        ];\n        for (kind, path) in candidates {\n            if path.is_dir() {\n                targets.push((slice.index, kind, path));\n            }\n        }\n    }\n\n    tracing::info!(\n        bundles = targets.len(),\n        parallel,\n        overwrite,\n        \"running holographic GKR setup over compiled bundles\"\n    );\n\n    let pool = rayon::ThreadPoolBuilder::new()\n        .num_threads(parallel)\n        .build()\n        .map_err(|e| DsperseError::Pipeline(format!(\"thread pool: {e}\")))?;\n\n    let processed = std::sync::atomic::AtomicUsize::new(0);\n    let skipped = std::sync::atomic::AtomicUsize::new(0);\n    let errors: std::sync::Mutex<Vec<(usize, DsperseError)>> = std::sync::Mutex::new(Vec::new());\n\n    pool.install(|| {\n        targets\n            .par_iter()\n            .for_each(|(slice_idx, kind, bundle_path)| {\n                if !overwrite && jstprove_io::bundle::bundle_has_vk(bundle_path) {\n                    skipped.fetch_add(1, std::sync::atomic::Ordering::Relaxed);\n                    tracing::info!(\n                        slice = *slice_idx,\n                        kind,\n                        path = %bundle_path.display(),\n                        \"vk.bin already present, skipping (pass --overwrite to regenerate)\"\n                    );\n                    return;\n                }\n                match run_holographic_setup(backend, bundle_path, *slice_idx, kind) {\n                    Ok(()) => {\n                        processed.fetch_add(1, 
std::sync::atomic::Ordering::Relaxed);\n                    }\n                    Err(e) => {\n                        tracing::warn!(slice = *slice_idx, kind, error = %e, \"holographic setup failed\");\n                        errors.lock().unwrap().push((*slice_idx, e));\n                    }\n                }\n            });\n    });\n\n    Ok(HolographicSetupReport {\n        processed: processed.load(std::sync::atomic::Ordering::Relaxed),\n        skipped_already_present: skipped.load(std::sync::atomic::Ordering::Relaxed),\n        failed: errors.into_inner().unwrap(),\n    })\n}\n\nfn run_holographic_setup(\n    backend: &JstproveBackend,\n    circuit_path: &Path,\n    slice_idx: usize,\n    kind: &'static str,\n) -> Result<()> {\n    tracing::info!(\n        slice = slice_idx,\n        kind,\n        path = %circuit_path.display(),\n        \"running holographic GKR setup\"\n    );\n    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {\n        backend.setup_holographic_vk(circuit_path)\n    }))\n    .map_err(|p| {\n        let msg = p\n            .downcast_ref::<&str>()\n            .copied()\n            .or_else(|| p.downcast_ref::<String>().map(String::as_str))\n            .unwrap_or(\"unknown panic\");\n        DsperseError::Backend(format!(\n            \"jstprove panicked during holographic setup on slice {slice_idx} ({kind}): {msg}\"\n        ))\n    })??;\n    tracing::info!(slice = slice_idx, kind, \"holographic vk persisted\");\n    Ok(())\n}\n\nfn populate_channel_split_groups(\n    slices_dir: &Path,\n    slice_idx: usize,\n    cs: &mut crate::schema::tiling::ChannelSplitInfo,\n) -> Result<bool> {\n    let groups_dir = slices_dir\n        .join(format!(\"slice_{slice_idx}\"))\n        .join(\"payload\")\n        .join(\"channel_groups\");\n    if !groups_dir.exists() {\n        return Ok(false);\n    }\n\n    let cpg = cs.channels_per_group;\n    let mut groups = Vec::with_capacity(cs.num_groups);\n    for g in 0..cs.num_groups 
{\n        let c_start = g.checked_mul(cpg).ok_or_else(|| {\n            DsperseError::Slicer(format!(\"overflow computing c_start for group {g}\"))\n        })?;\n        let c_end = (g + 1)\n            .checked_mul(cpg)\n            .map(|v| v.min(cs.c_in))\n            .ok_or_else(|| {\n                DsperseError::Slicer(format!(\"overflow computing c_end for group {g}\"))\n            })?;\n        let rel_path = format!(\"slice_{slice_idx}/payload/channel_groups/group_{g}.onnx\");\n        let abs_path = slices_dir.join(&rel_path);\n        if !abs_path.exists() {\n            tracing::warn!(\n                slice = slice_idx,\n                group = g,\n                \"expected group ONNX not found, skipping population\"\n            );\n            return Ok(false);\n        }\n        groups.push(crate::schema::tiling::ChannelGroupInfo {\n            group_idx: g,\n            c_start,\n            c_end,\n            path: rel_path,\n            jstprove_circuit_path: None,\n            jstprove_settings_path: None,\n        });\n    }\n\n    let bias_rel = format!(\"slice_{slice_idx}/payload/channel_groups/bias.msgpack\");\n    if slices_dir.join(&bias_rel).exists() {\n        cs.bias_path = Some(bias_rel);\n    }\n\n    tracing::info!(\n        slice = slice_idx,\n        groups = groups.len(),\n        \"populated channel split groups from materialized files\"\n    );\n    cs.groups = groups;\n    Ok(true)\n}\n\n#[allow(clippy::too_many_arguments)]\nfn compile_channel_split_slice(\n    slices_dir: &Path,\n    slice: &crate::schema::metadata::SliceMetadata,\n    cs: &crate::schema::tiling::ChannelSplitInfo,\n    backend: &JstproveBackend,\n    proof_config: jstprove_circuits::api::ProofConfigType,\n    jstprove_ops: &[&str],\n    exclude_from_wai: &std::collections::HashSet<String>,\n    skip_compile_over_size: Option<u64>,\n    circuit_cache: &CircuitCache,\n    traced_shapes: Option<&std::collections::HashMap<String, Vec<i64>>>,\n    
holographic: bool,\n) -> Result<CompileOutcome> {\n    let slice_dir = slice_dir_path(slices_dir, slice.index);\n    let jst_dir = slice_dir.join(\"jstprove\");\n    std::fs::create_dir_all(&jst_dir).map_err(|e| DsperseError::io(e, &jst_dir))?;\n\n    let shared_circuit_rel = format!(\"slice_{}/jstprove/shared/circuit.bundle\", slice.index);\n    let shared_circuit_path = jst_dir.join(\"shared\").join(\"circuit.bundle\");\n\n    // Treat an existing shared bundle the same way the standard-\n    // slice path does: try to load it; if load_params rejects it\n    // (version drift, partial write, corruption), drop the stale\n    // directory and fall through to the compile-fresh branch so a\n    // single bad bundle doesn't permanently wedge every slice in\n    // the channel-split group.  The fresh-build code below is\n    // unchanged and will re-populate from the circuit cache or\n    // via backend.compile as appropriate.\n    let mut needs_build = !shared_circuit_path.is_dir();\n    if !needs_build {\n        match backend.load_params(&shared_circuit_path) {\n            Ok(_) => {\n                tracing::info!(\n                    slice = slice.index,\n                    \"shared circuit already compiled, reusing\"\n                );\n            }\n            Err(e) => {\n                tracing::warn!(\n                    slice = slice.index,\n                    error = %e,\n                    \"cached shared circuit invalid, recompiling\"\n                );\n                std::fs::remove_dir_all(&shared_circuit_path)\n                    .map_err(|e| DsperseError::io(e, &shared_circuit_path))?;\n                needs_build = true;\n            }\n        }\n    }\n\n    if needs_build {\n        let first_group = cs.groups.first().ok_or_else(|| {\n            DsperseError::Pipeline(format!(\"slice {} channel_split has no groups\", slice.index))\n        })?;\n        let onnx_path = slices_dir.join(&first_group.path);\n        if 
!onnx_path.exists() {\n            return Err(DsperseError::Pipeline(format!(\n                \"channel group ONNX not found: {}\",\n                onnx_path.display()\n            )));\n        }\n\n        let analysis = analyze_slice_onnx(&onnx_path, jstprove_ops)?;\n        if !analysis.compatible {\n            return Err(DsperseError::Pipeline(format!(\n                \"slice {} group 0 has unsupported ops for circuit compilation\",\n                slice.index\n            )));\n        }\n\n        if let Some(threshold) = skip_compile_over_size {\n            let estimated = estimate_onnx_constraints(&onnx_path)?;\n            if estimated > threshold {\n                return Ok(CompileOutcome::SkippedOverSize {\n                    estimated,\n                    threshold,\n                });\n            }\n        }\n\n        let sig = compute_circuit_signature(&onnx_path, None)?;\n\n        let cached = circuit_cache.lock().unwrap().get(&sig).cloned();\n        if let Some(ref cached_path) = cached\n            && cached_path.is_dir()\n        {\n            let shared_dir = shared_circuit_path.parent().ok_or_else(|| {\n                DsperseError::Pipeline(\"shared circuit path has no parent\".into())\n            })?;\n            std::fs::create_dir_all(shared_dir).map_err(|e| DsperseError::io(e, shared_dir))?;\n            copy_dir_recursive(cached_path, &shared_circuit_path)?;\n            tracing::info!(\n                slice = slice.index,\n                sig = %sig,\n                \"reused cached channel-split circuit from prior slice\"\n            );\n        } else {\n            let shared_dir = shared_circuit_path.parent().ok_or_else(|| {\n                DsperseError::Pipeline(\"shared circuit path has no parent\".into())\n            })?;\n            std::fs::create_dir_all(shared_dir).map_err(|e| DsperseError::io(e, shared_dir))?;\n\n            tracing::info!(\n                slice = slice.index,\n                groups = 
cs.groups.len(),\n                sig = %sig,\n                \"compiling shared channel group circuit (weights-as-inputs)\"\n            );\n\n            let (params, architecture, wandb) = converter::prepare_jstprove_artifacts_filtered(\n                &onnx_path,\n                true,\n                exclude_from_wai,\n                traced_shapes,\n            )?;\n\n            std::panic::catch_unwind(|| {\n                backend.compile(\n                    &shared_circuit_path,\n                    proof_config,\n                    params,\n                    architecture,\n                    wandb,\n                )\n            })\n            .map_err(|p| {\n                let msg = p\n                    .downcast_ref::<&str>()\n                    .copied()\n                    .or_else(|| p.downcast_ref::<String>().map(String::as_str))\n                    .unwrap_or(\"unknown panic\");\n                DsperseError::Backend(format!(\n                    \"jstprove panicked on slice {} shared circuit: {msg}\",\n                    slice.index\n                ))\n            })??;\n\n            circuit_cache\n                .lock()\n                .unwrap()\n                .insert(sig.clone(), shared_circuit_path.clone());\n            tracing::info!(slice = slice.index, sig = %sig, \"shared circuit compiled\");\n        }\n\n        // One final load to match the cached-bundle branch's\n        // invariant: the function returns only after we have seen\n        // a viable shared circuit at shared_circuit_path.  
If the\n        // freshly-built bundle still fails to load, a retry would\n        // recurse indefinitely, so surface the error.\n        backend.load_params(&shared_circuit_path).map_err(|e| {\n            DsperseError::Pipeline(format!(\n                \"slice {} freshly-built shared circuit failed to load: {e}\",\n                slice.index\n            ))\n        })?;\n\n        if holographic && !jstprove_io::bundle::bundle_has_vk(&shared_circuit_path) {\n            // When the needs_build branch took the memcache-reuse\n            // sub-path, copy_dir_recursive may already have brought\n            // vk.bin across from the source bundle; skip the\n            // expensive re-setup in that case and only run it when\n            // the shared bundle genuinely lacks a vk (the\n            // fresh-compile sub-path, or a source that raced us\n            // before its own setup persisted).\n            run_holographic_setup(\n                backend,\n                &shared_circuit_path,\n                slice.index,\n                \"channel-split-shared\",\n            )?;\n        }\n    } else if holographic && !jstprove_io::bundle::bundle_has_vk(&shared_circuit_path) {\n        // Cached bundle predates the holographic plumbing: backfill\n        // the vk so reused circuits stay in sync with freshly-built\n        // ones.\n        run_holographic_setup(\n            backend,\n            &shared_circuit_path,\n            slice.index,\n            \"channel-split-shared\",\n        )?;\n    }\n\n    let group_circuits: Vec<(usize, String)> = cs\n        .groups\n        .iter()\n        .map(|g| (g.group_idx, shared_circuit_rel.clone()))\n        .collect();\n\n    Ok(CompileOutcome::CompiledChannelSplit { group_circuits })\n}\n\n#[allow(clippy::too_many_arguments)]\nfn compile_dim_split_template(\n    slices_dir: &Path,\n    slice: &crate::schema::metadata::SliceMetadata,\n    tmpl_path: &Path,\n    backend: &JstproveBackend,\n    proof_config: 
jstprove_circuits::api::ProofConfigType,\n    jstprove_ops: &[&str],\n    exclude_from_wai: &std::collections::HashSet<String>,\n    skip_compile_over_size: Option<u64>,\n    circuit_cache: &CircuitCache,\n    _traced_shapes: Option<&std::collections::HashMap<String, Vec<i64>>>,\n    holographic: bool,\n) -> Result<CompileOutcome> {\n    let slice_dir = slice_dir_path(slices_dir, slice.index);\n    let jst_dir = slice_dir.join(\"jstprove\");\n    std::fs::create_dir_all(&jst_dir).map_err(|e| DsperseError::io(e, &jst_dir))?;\n\n    let circuit_path = jst_dir.join(\"dim_split\").join(\"circuit.bundle\");\n\n    if circuit_path.is_dir() {\n        match backend.load_params(&circuit_path) {\n            Ok(_) => {\n                tracing::info!(\n                    slice = slice.index,\n                    \"dim-split template already compiled, reusing\"\n                );\n                if holographic && !jstprove_io::bundle::bundle_has_vk(&circuit_path) {\n                    // Backfill vk on cached bundles; see channel-\n                    // split branch above for the same rationale.\n                    run_holographic_setup(\n                        backend,\n                        &circuit_path,\n                        slice.index,\n                        \"dim-split-template\",\n                    )?;\n                }\n                return Ok(CompileOutcome::CompiledDimSplit);\n            }\n            Err(e) => {\n                tracing::warn!(slice = slice.index, error = %e, \"cached dim-split circuit invalid, recompiling\");\n                std::fs::remove_dir_all(&circuit_path)\n                    .map_err(|e| DsperseError::io(e, &circuit_path))?;\n            }\n        }\n    }\n\n    let analysis = analyze_slice_onnx(tmpl_path, jstprove_ops)?;\n    if !analysis.compatible {\n        return Ok(CompileOutcome::Skipped);\n    }\n\n    if let Some(threshold) = skip_compile_over_size {\n        let estimated = slice\n            
.dim_split\n            .as_ref()\n            .map(|ds| ds.estimated_group_constraints)\n            .filter(|&e| e > 0)\n            .or_else(|| match estimate_onnx_constraints(tmpl_path) {\n                Ok(e) => Some(e),\n                Err(err) => {\n                    // We can't turn an unknown cost into a safe\n                    // gating decision, so fall through and let the\n                    // compile attempt surface the real error rather\n                    // than silently treating the slice as tiny.\n                    tracing::warn!(\n                        slice = slice.index,\n                        onnx = %tmpl_path.display(),\n                        error = %err,\n                        \"skip_compile_over_size: constraint estimate failed; proceeding to compile\"\n                    );\n                    None\n                }\n            });\n        if let Some(estimated) = estimated\n            && estimated > threshold\n        {\n            return Ok(CompileOutcome::SkippedOverSize {\n                estimated,\n                threshold,\n            });\n        }\n    }\n\n    let sig = compute_circuit_signature(tmpl_path, None)?;\n\n    let cached = circuit_cache.lock().unwrap().get(&sig).cloned();\n    if let Some(ref cached_path) = cached\n        && cached_path.is_dir()\n    {\n        let shared_dir = circuit_path\n            .parent()\n            .ok_or_else(|| DsperseError::Pipeline(\"dim-split circuit path has no parent\".into()))?;\n        std::fs::create_dir_all(shared_dir).map_err(|e| DsperseError::io(e, shared_dir))?;\n        copy_dir_recursive(cached_path, &circuit_path)?;\n        tracing::info!(\n            slice = slice.index,\n            sig = %sig,\n            \"reused cached dim-split circuit\"\n        );\n        // circuit_cache can hand back a source bundle that was\n        // inserted before its own run_holographic_setup finished\n        // (the fresh-build branch below inserts the sig 
before it\n        // persists vk.bin), so a parallel racer can snapshot a\n        // pre-vk source and copy_dir_recursive a bundle missing\n        // vk.bin.  Mirror the channel-split reuse branch and\n        // backfill on the copy so every reused dim-split bundle\n        // ends up in the same shape as a freshly-compiled one.\n        if holographic && !jstprove_io::bundle::bundle_has_vk(&circuit_path) {\n            run_holographic_setup(backend, &circuit_path, slice.index, \"dim-split-template\")?;\n        }\n        return Ok(CompileOutcome::CompiledDimSplit);\n    }\n\n    let shared_dir = circuit_path\n        .parent()\n        .ok_or_else(|| DsperseError::Pipeline(\"dim-split circuit path has no parent\".into()))?;\n    std::fs::create_dir_all(shared_dir).map_err(|e| DsperseError::io(e, shared_dir))?;\n\n    tracing::info!(\n        slice = slice.index,\n        sig = %sig,\n        \"compiling dim-split template (weights-as-inputs)\"\n    );\n\n    // Do NOT pass the original traced_shapes when compiling dim-split\n    // templates. The template has rewritten shapes (dim_size → epg) that\n    // differ from the original model's traced shapes. 
If traced_shapes\n    // is passed, jstprove uses the original (larger) shapes and the\n    // Transpose/Reshape validation fails on the mismatch.\n    let (params, architecture, wandb) =\n        converter::prepare_jstprove_artifacts_filtered(tmpl_path, true, exclude_from_wai, None)?;\n\n    std::panic::catch_unwind(|| {\n        backend.compile(&circuit_path, proof_config, params, architecture, wandb)\n    })\n    .map_err(|p| {\n        let msg = p\n            .downcast_ref::<&str>()\n            .copied()\n            .or_else(|| p.downcast_ref::<String>().map(String::as_str))\n            .unwrap_or(\"unknown panic\");\n        DsperseError::Backend(format!(\n            \"jstprove panicked on slice {} dim-split template: {msg}\",\n            slice.index\n        ))\n    })??;\n\n    circuit_cache\n        .lock()\n        .unwrap()\n        .insert(sig.clone(), circuit_path.clone());\n    tracing::info!(slice = slice.index, sig = %sig, \"dim-split template compiled\");\n\n    if holographic {\n        run_holographic_setup(backend, &circuit_path, slice.index, \"dim-split-template\")?;\n    }\n\n    Ok(CompileOutcome::CompiledDimSplit)\n}\n\nfn copy_dir_recursive(src: &Path, dst: &Path) -> Result<()> {\n    std::fs::create_dir_all(dst).map_err(|e| DsperseError::io(e, dst))?;\n    for entry in std::fs::read_dir(src).map_err(|e| DsperseError::io(e, src))? 
{\n        let entry = entry.map_err(|e| DsperseError::io(e, src))?;\n        let ty = entry.file_type().map_err(|e| DsperseError::io(e, src))?;\n        let dst_path = dst.join(entry.file_name());\n        if ty.is_dir() {\n            copy_dir_recursive(&entry.path(), &dst_path)?;\n        } else {\n            std::fs::copy(entry.path(), &dst_path).map_err(|e| DsperseError::io(e, &dst_path))?;\n        }\n    }\n    Ok(())\n}\n\nfn resolve_compile_onnx(\n    slices_dir: &Path,\n    slice: &crate::schema::metadata::SliceMetadata,\n) -> Result<std::path::PathBuf> {\n    if let Some(ref tiling) = slice.tiling\n        && let Some(ref tile) = tiling.tile\n    {\n        let tile_path = slices_dir.join(&tile.path);\n        if tile_path.exists() {\n            tracing::info!(\n                slice = slice.index,\n                path = %tile_path.display(),\n                \"using tile ONNX\"\n            );\n            return Ok(tile_path);\n        }\n    }\n\n    slice.resolve_onnx(slices_dir)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::schema::metadata::{\n        Compilation, Dependencies, SliceMetadata, SliceShapeWrapper, TensorShape,\n    };\n    use crate::schema::tiling::{TileInfo, TilingInfo};\n\n    fn test_models_dir() -> std::path::PathBuf {\n        std::path::PathBuf::from(concat!(env!(\"CARGO_MANIFEST_DIR\"), \"/../../tests/models\"))\n    }\n\n    fn make_slice_metadata(index: usize, path: &str) -> SliceMetadata {\n        SliceMetadata {\n            index,\n            filename: format!(\"slice_{index}.onnx\"),\n            path: path.to_string(),\n            relative_path: path.to_string(),\n            shape: SliceShapeWrapper {\n                tensor_shape: TensorShape::default(),\n            },\n            dependencies: Dependencies {\n                input: vec![],\n                output: vec![],\n                filtered_inputs: vec![],\n            },\n            tiling: None,\n            channel_split: 
None,\n            dim_split: None,\n            compilation: Compilation::default(),\n            slice_metadata: None,\n            slice_metadata_relative_path: None,\n        }\n    }\n\n    const TEST_OPS: &[&str] = &[\"Conv\", \"Gemm\", \"MatMul\"];\n\n    #[test]\n    fn analyze_slice_onnx_nonexistent() {\n        let result = analyze_slice_onnx(Path::new(\"/nonexistent.onnx\"), TEST_OPS);\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn analyze_slice_onnx_test_model() {\n        let model_path = test_models_dir().join(\"net/model.onnx\");\n        assert!(\n            model_path.exists(),\n            \"fixture missing: {}\",\n            model_path.display()\n        );\n        let analysis = analyze_slice_onnx(&model_path, TEST_OPS).unwrap();\n        assert!(!analysis.compatible);\n    }\n\n    #[test]\n    fn analyze_slice_onnx_with_initializers() {\n        let tmp = tempfile::tempdir().unwrap();\n        let path = tmp.path().join(\"with_init.onnx\");\n        let model = onnx_proto::ModelProto {\n            graph: Some(onnx_proto::GraphProto {\n                node: vec![onnx_proto::make_node(\"Conv\", vec![], vec![], vec![])],\n                initializer: vec![onnx_proto::make_tensor(\n                    \"weight\",\n                    1,\n                    &[3, 3, 3, 3],\n                    vec![0.0; 81],\n                )],\n                ..Default::default()\n            }),\n            ..Default::default()\n        };\n        onnx_proto::save_model(&model, &path).unwrap();\n        let analysis = analyze_slice_onnx(&path, &[\"Conv\"]).unwrap();\n        assert!(analysis.compatible);\n    }\n\n    #[test]\n    fn analyze_slice_onnx_without_initializers() {\n        let tmp = tempfile::tempdir().unwrap();\n        let path = tmp.path().join(\"no_init.onnx\");\n        let model = onnx_proto::ModelProto {\n            graph: Some(onnx_proto::GraphProto {\n                node: vec![onnx_proto::make_node(\"Relu\", 
vec![], vec![], vec![])],\n                initializer: vec![],\n                ..Default::default()\n            }),\n            ..Default::default()\n        };\n        onnx_proto::save_model(&model, &path).unwrap();\n        let analysis = analyze_slice_onnx(&path, &[\"Relu\"]).unwrap();\n        assert!(analysis.compatible);\n    }\n\n    #[test]\n    fn resolve_compile_onnx_no_tiling() {\n        let tmp = tempfile::tempdir().unwrap();\n        let slices_dir = tmp.path();\n        let slice_dir = slices_dir.join(\"slice_0\");\n        std::fs::create_dir_all(&slice_dir).unwrap();\n\n        let meta = make_slice_metadata(0, \"slice_0.onnx\");\n        let path = resolve_compile_onnx(slices_dir, &meta).unwrap();\n        assert!(path.ends_with(\"slice_0.onnx\"));\n    }\n\n    #[test]\n    fn resolve_compile_onnx_with_tile() {\n        let tmp = tempfile::tempdir().unwrap();\n        let slices_dir = tmp.path();\n        let tile_path = slices_dir.join(\"slice_0/payload/tiles/tile.onnx\");\n        std::fs::create_dir_all(tile_path.parent().unwrap()).unwrap();\n        std::fs::write(&tile_path, b\"dummy\").unwrap();\n\n        let mut meta = make_slice_metadata(0, \"slice_0.onnx\");\n        meta.tiling = Some(TilingInfo {\n            slice_idx: 0,\n            tile_size: 8,\n            num_tiles: 4,\n            tiles_y: 2,\n            tiles_x: 2,\n            halo: [1, 1, 1, 1],\n            out_tile: [4, 4],\n            stride: [1, 1],\n            c_in: 3,\n            c_out: 16,\n            input_name: \"input\".into(),\n            output_name: \"output\".into(),\n            input_names: vec![],\n            ndim: 4,\n            h: 16,\n            w: 16,\n            tile: Some(TileInfo {\n                path: \"slice_0/payload/tiles/tile.onnx\".into(),\n                conv_out: [4, 4],\n                jstprove_circuit_path: None,\n            }),\n            tiles: None,\n            segment_size: None,\n            total_elements: 
None,\n            original_shape: vec![],\n        });\n        let path = resolve_compile_onnx(slices_dir, &meta).unwrap();\n        assert!(path.ends_with(\"tile.onnx\"));\n    }\n\n    #[test]\n    fn resolve_compile_onnx_tile_missing_falls_back() {\n        let tmp = tempfile::tempdir().unwrap();\n        let slices_dir = tmp.path();\n        let slice_dir = slices_dir.join(\"slice_0\");\n        std::fs::create_dir_all(&slice_dir).unwrap();\n\n        let mut meta = make_slice_metadata(0, \"slice_0.onnx\");\n        meta.tiling = Some(TilingInfo {\n            slice_idx: 0,\n            tile_size: 8,\n            num_tiles: 4,\n            tiles_y: 2,\n            tiles_x: 2,\n            halo: [1, 1, 1, 1],\n            out_tile: [4, 4],\n            stride: [1, 1],\n            c_in: 3,\n            c_out: 16,\n            input_name: \"input\".into(),\n            output_name: \"output\".into(),\n            input_names: vec![],\n            ndim: 4,\n            h: 16,\n            w: 16,\n            tile: Some(TileInfo {\n                path: \"slice_0/payload/tiles/nonexistent.onnx\".into(),\n                conv_out: [4, 4],\n                jstprove_circuit_path: None,\n            }),\n            tiles: None,\n            segment_size: None,\n            total_elements: None,\n            original_shape: vec![],\n        });\n        let path = resolve_compile_onnx(slices_dir, &meta).unwrap();\n        assert!(path.ends_with(\"slice_0.onnx\"));\n    }\n\n    fn write_identity_onnx(path: &Path) {\n        let node = onnx_proto::NodeProto {\n            op_type: \"Relu\".to_string(),\n            input: vec![\"x\".to_string()],\n            output: vec![\"y\".to_string()],\n            ..Default::default()\n        };\n        let graph = onnx_proto::make_graph(\n            \"g\",\n            vec![node],\n            vec![onnx_proto::make_tensor_value_info(\"x\", 1, &[1, 8])],\n            vec![onnx_proto::make_tensor_value_info(\"y\", 1, &[1, 
8])],\n            vec![],\n        );\n        let model = onnx_proto::make_model(graph, 13);\n        onnx_proto::save_model(&model, path).unwrap();\n    }\n\n    #[test]\n    fn bundle_signature_differs_from_circuit_signature_even_without_metadata() {\n        let tmp = tempfile::tempdir().unwrap();\n        let onnx_path = tmp.path().join(\"slice.onnx\");\n        write_identity_onnx(&onnx_path);\n        let bundle_dir = tmp.path().join(\"bundle\");\n        std::fs::create_dir_all(&bundle_dir).unwrap();\n\n        let base = compute_circuit_signature(&onnx_path, None).unwrap();\n        let bundle_sig = compute_bundle_signature(&onnx_path, None, &bundle_dir).unwrap();\n\n        assert_ne!(\n            base, bundle_sig,\n            \"bundle signature must always include discriminator bytes\"\n        );\n\n        let bundle_sig_again = compute_bundle_signature(&onnx_path, None, &bundle_dir).unwrap();\n        assert_eq!(\n            bundle_sig, bundle_sig_again,\n            \"bundle signature must be deterministic\"\n        );\n    }\n\n    #[test]\n    fn bundle_signature_disambiguates_vk_presence() {\n        let tmp = tempfile::tempdir().unwrap();\n        let onnx_path = tmp.path().join(\"slice.onnx\");\n        write_identity_onnx(&onnx_path);\n\n        let plain_bundle = tmp.path().join(\"plain\");\n        let holo_bundle = tmp.path().join(\"holographic\");\n        std::fs::create_dir_all(&plain_bundle).unwrap();\n        std::fs::create_dir_all(&holo_bundle).unwrap();\n        std::fs::write(holo_bundle.join(\"vk.bin\"), b\"vk-contents\").unwrap();\n\n        let plain_sig = compute_bundle_signature(&onnx_path, None, &plain_bundle).unwrap();\n        let holo_sig = compute_bundle_signature(&onnx_path, None, &holo_bundle).unwrap();\n\n        assert_ne!(\n            plain_sig, holo_sig,\n            \"holographic bundle must produce a distinct signature\"\n        );\n    }\n\n    #[test]\n    fn 
bundle_signature_disambiguates_proof_config_and_wai_on_metadata_branch() {\n        use std::collections::HashMap;\n\n        use jstprove_circuits::ProofSystem;\n        use jstprove_circuits::api::{CircuitParamsType, ProofConfigType, StampedProofConfigType};\n        use jstprove_io::bundle::write_bundle;\n\n        let tmp = tempfile::tempdir().unwrap();\n        let onnx_path = tmp.path().join(\"slice.onnx\");\n        write_identity_onnx(&onnx_path);\n\n        fn make_params(config: ProofConfigType, weights_as_inputs: bool) -> CircuitParamsType {\n            CircuitParamsType {\n                scale_base: 2,\n                scale_exponent: 8,\n                rescale_config: HashMap::new(),\n                inputs: Vec::new(),\n                outputs: Vec::new(),\n                freivalds_reps: 1,\n                n_bits_config: HashMap::new(),\n                weights_as_inputs,\n                proof_system: ProofSystem::default(),\n                proof_config: Some(StampedProofConfigType::current(config)),\n                logup_chunk_bits: None,\n                public_inputs: Vec::new(),\n            }\n        }\n\n        let bn254_bundle = tmp.path().join(\"bn254\");\n        let goldi_bundle = tmp.path().join(\"goldilocks\");\n        let bn254_wai_bundle = tmp.path().join(\"bn254-wai\");\n\n        write_bundle(\n            &bn254_bundle,\n            &[1, 2, 3],\n            &[4, 5, 6],\n            Some(make_params(ProofConfigType::Bn254Raw, false)),\n            None,\n            false,\n        )\n        .unwrap();\n        write_bundle(\n            &goldi_bundle,\n            &[1, 2, 3],\n            &[4, 5, 6],\n            Some(make_params(ProofConfigType::GoldilocksExt4Whir, false)),\n            None,\n            false,\n        )\n        .unwrap();\n        write_bundle(\n            &bn254_wai_bundle,\n            &[1, 2, 3],\n            &[4, 5, 6],\n            Some(make_params(ProofConfigType::Bn254Raw, true)),\n            
None,\n            false,\n        )\n        .unwrap();\n\n        let sig_bn254 = compute_bundle_signature(&onnx_path, None, &bn254_bundle).unwrap();\n        let sig_goldi = compute_bundle_signature(&onnx_path, None, &goldi_bundle).unwrap();\n        let sig_bn254_wai = compute_bundle_signature(&onnx_path, None, &bn254_wai_bundle).unwrap();\n\n        assert_ne!(\n            sig_bn254, sig_goldi,\n            \"config_id must discriminate bundles with different ProofConfig variants\"\n        );\n        assert_ne!(\n            sig_bn254, sig_bn254_wai,\n            \"weights_as_inputs must discriminate bundles with the same ProofConfig\"\n        );\n\n        let sig_bn254_again = compute_bundle_signature(&onnx_path, None, &bn254_bundle).unwrap();\n        assert_eq!(\n            sig_bn254, sig_bn254_again,\n            \"signature must be deterministic for the metadata branch\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/dim_split.rs",
    "content": "use std::collections::HashMap;\nuse std::path::Path;\n\nuse super::runner::{run_onnx_inference, run_onnx_inference_multi_named};\nuse super::tensor_store::TensorStore;\nuse crate::backend::jstprove::JstproveBackend;\nuse crate::error::{DsperseError, Result};\nuse crate::schema::execution::ExecutionInfo;\nuse crate::schema::tiling::DimSplitKind;\nuse crate::slicer::onnx_proto::TensorProto;\n\n#[allow(clippy::too_many_arguments)]\npub(crate) fn execute_dim_split(\n    slices_dir: &Path,\n    _slice_run_dir: &Path,\n    slice_id: &str,\n    ds: &crate::schema::tiling::DimSplitInfo,\n    target_shape: Option<&[i64]>,\n    tensor_cache: &TensorStore,\n    _backend: &JstproveBackend,\n    donor_init_map: Option<&HashMap<String, &TensorProto>>,\n) -> Result<crate::schema::execution::StrategyOutput> {\n    let tmpl_rel = ds.template_path.as_ref().ok_or_else(|| {\n        DsperseError::Pipeline(format!(\"{slice_id}: dim_split has no template_path\"))\n    })?;\n    let tmpl_path = slices_dir.join(tmpl_rel);\n    if !tmpl_path.exists() {\n        return Err(DsperseError::Pipeline(format!(\n            \"{slice_id}: dim-split template not found: {}\",\n            tmpl_path.display()\n        )));\n    }\n\n    let use_matmul_split = matches!(ds.split_kind, DimSplitKind::MatMulOutputDim);\n\n    let final_result = if use_matmul_split {\n        execute_matmul_dim_split(\n            slices_dir,\n            slice_id,\n            ds,\n            target_shape,\n            tensor_cache,\n            &tmpl_path,\n            donor_init_map,\n        )?\n    } else {\n        execute_generic_dim_split(slice_id, ds, target_shape, tensor_cache, &tmpl_path)?\n    };\n\n    Ok(crate::schema::execution::StrategyOutput {\n        info: ExecutionInfo {\n            method: crate::schema::execution::ExecutionMethod::DimSplit,\n            success: true,\n            error: None,\n            witness_file: None,\n            tile_exec_infos: Vec::new(),\n        },\n     
   outputs: vec![(ds.output_name.clone(), final_result)],\n    })\n}\n\n#[allow(clippy::too_many_arguments)]\nfn execute_matmul_dim_split(\n    slices_dir: &Path,\n    slice_id: &str,\n    ds: &crate::schema::tiling::DimSplitInfo,\n    target_shape: Option<&[i64]>,\n    tensor_cache: &TensorStore,\n    tmpl_path: &Path,\n    donor_init_map: Option<&HashMap<String, &TensorProto>>,\n) -> Result<ndarray::ArrayD<f64>> {\n    let input_tensor = tensor_cache.get(&ds.input_name)?.clone();\n    let input_shape = input_tensor.shape().to_vec();\n    let k_dim = *input_shape.last().unwrap_or(&0);\n    if ds.k_dim != 0 && k_dim != ds.k_dim {\n        return Err(DsperseError::Pipeline(format!(\n            \"{slice_id}: runtime k_dim {} from input {:?} does not match metadata k_dim {}\",\n            k_dim, ds.input_name, ds.k_dim\n        )));\n    }\n    if k_dim == 0 {\n        return Err(DsperseError::Pipeline(format!(\n            \"{slice_id}: dim-split input {:?} has zero-width last dim; expected k_dim > 0\",\n            ds.input_name\n        )));\n    }\n    let k_chunks = ds.k_chunks.max(1);\n    let k_chunk_size = k_dim.div_ceil(k_chunks);\n\n    let total_rows: usize = input_shape\n        .iter()\n        .take(input_shape.len().saturating_sub(1))\n        .product();\n    let flat_input = input_tensor\n        .as_standard_layout()\n        .into_owned()\n        .into_shape_with_order(ndarray::IxDyn(&[total_rows, k_dim]))\n        .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: flatten input: {e}\")))?;\n\n    let slice_onnx_path = slices_dir\n        .join(format!(\"slice_{}\", ds.slice_idx))\n        .join(\"payload\")\n        .join(format!(\"slice_{}.onnx\", ds.slice_idx));\n\n    let orig_model = crate::slicer::onnx_proto::load_model(&slice_onnx_path)?;\n    let orig_graph = orig_model\n        .graph\n        .as_ref()\n        .ok_or_else(|| DsperseError::Pipeline(format!(\"{slice_id}: slice ONNX has no graph\")))?;\n    let weight_name = 
ds.weight_name.as_ref().ok_or_else(|| {\n        DsperseError::Pipeline(format!(\n            \"{slice_id}: dim_split missing weight_name in metadata\"\n        ))\n    })?;\n    let matmul_node = orig_graph\n        .node\n        .iter()\n        .find(|n| {\n            matches!(n.op_type.as_str(), \"MatMul\" | \"Gemm\")\n                && n.input.iter().any(|i| i == weight_name)\n                && n.input.iter().any(|i| i == &ds.input_name)\n                && n.output.iter().any(|o| o == &ds.output_name)\n        })\n        .ok_or_else(|| {\n            DsperseError::Pipeline(format!(\n                \"{slice_id}: no MatMul/Gemm node matches weight={weight_name:?} input={:?} output={:?}\",\n                ds.input_name, ds.output_name\n            ))\n        })?;\n    let trans_b = matmul_node.op_type == \"Gemm\"\n        && crate::slicer::onnx_proto::get_attribute_int(matmul_node, \"transB\").unwrap_or(0) == 1;\n    let full_weight: Vec<f32> = if let Some(map) = donor_init_map\n        && let Some(t) = map.get(weight_name.as_str())\n    {\n        crate::slicer::onnx_proto::tensor_to_f32(t)\n    } else {\n        let init = orig_graph\n            .initializer\n            .iter()\n            .find(|i| i.name == *weight_name)\n            .ok_or_else(|| {\n                DsperseError::Pipeline(format!(\n                    \"{slice_id}: weight {weight_name:?} not found in slice ONNX initializers\"\n                ))\n            })?;\n        crate::slicer::onnx_proto::tensor_to_f32(init)\n    };\n    let expected_weight_len = ds.k_dim.saturating_mul(ds.n_dim);\n    if expected_weight_len > 0 && full_weight.len() != expected_weight_len {\n        return Err(DsperseError::Pipeline(format!(\n            \"{slice_id}: weight {weight_name:?} length {} does not match expected k_dim*n_dim = {}*{} = {}\",\n            full_weight.len(),\n            ds.k_dim,\n            ds.n_dim,\n            expected_weight_len\n        )));\n    }\n\n    let n_dim = 
ds.n_dim;\n    let tmpl_model = crate::slicer::onnx_proto::load_model(tmpl_path)?;\n\n    let tmp_dir = tempfile::tempdir()\n        .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: tmpdir: {e}\")))?;\n\n    let mut patched_paths: Vec<std::path::PathBuf> = Vec::with_capacity(k_chunks);\n    for kc in 0..k_chunks {\n        let k_start = kc * k_chunk_size;\n        let k_end = (k_start + k_chunk_size).min(k_dim);\n        let actual_k = k_end.saturating_sub(k_start);\n\n        let weight_chunk: Vec<f32> = if trans_b {\n            let mut w = Vec::with_capacity(n_dim * k_chunk_size);\n            for row_idx in 0..n_dim {\n                let row_start = row_idx * k_dim + k_start;\n                let avail = actual_k.min(full_weight.len().saturating_sub(row_start));\n                w.extend_from_slice(&full_weight[row_start..row_start + avail]);\n                if avail < k_chunk_size {\n                    w.resize(w.len() + k_chunk_size - avail, 0.0);\n                }\n            }\n            w\n        } else {\n            let mut w = Vec::with_capacity(k_chunk_size * n_dim);\n            for ki in k_start..k_start + actual_k {\n                let start = ki * n_dim;\n                let end = start + n_dim;\n                if end <= full_weight.len() {\n                    w.extend_from_slice(&full_weight[start..end]);\n                } else {\n                    w.resize(w.len() + n_dim, 0.0);\n                }\n            }\n            if actual_k < k_chunk_size {\n                w.resize(k_chunk_size * n_dim, 0.0);\n            }\n            w\n        };\n\n        let mut patched = tmpl_model.clone();\n        let graph = patched.graph.as_mut().ok_or_else(|| {\n            DsperseError::Pipeline(format!(\n                \"{slice_id}: dim-split template at {} has no graph\",\n                tmpl_path.display()\n            ))\n        })?;\n        let w_init = graph\n            .initializer\n            .iter_mut()\n           
 .find(|i| i.name == \"W\")\n            .ok_or_else(|| {\n                DsperseError::Pipeline(format!(\n                    \"{slice_id}: dim-split template at {} missing 'W' initializer\",\n                    tmpl_path.display()\n                ))\n            })?;\n        w_init.float_data = weight_chunk;\n        w_init.raw_data.clear();\n\n        let patched_path = tmp_dir.path().join(format!(\"chunk_{kc}.onnx\"));\n        crate::slicer::onnx_proto::save_model(&patched, &patched_path)?;\n        patched_paths.push(patched_path);\n    }\n\n    let mut row_outputs: Vec<ndarray::ArrayD<f64>> = Vec::with_capacity(total_rows);\n\n    for r in 0..total_rows {\n        let full_row: Vec<f64> = flat_input\n            .slice(ndarray::s![r, ..])\n            .iter()\n            .copied()\n            .collect();\n\n        let mut row_accum = vec![0.0f64; n_dim];\n\n        for (kc, patched_path) in patched_paths.iter().enumerate() {\n            let k_start = kc * k_chunk_size;\n            let k_end = (k_start + k_chunk_size).min(k_dim);\n            let actual_k = k_end.saturating_sub(k_start);\n\n            let mut input_chunk = vec![0.0f64; k_chunk_size];\n            if actual_k > 0 {\n                input_chunk[..actual_k].copy_from_slice(&full_row[k_start..k_end]);\n            }\n\n            let input_arr =\n                ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[1, k_chunk_size]), input_chunk)\n                    .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: input chunk: {e}\")))?;\n\n            let out = run_onnx_inference(patched_path, &input_arr)?;\n            if out.len() != n_dim {\n                return Err(DsperseError::Pipeline(format!(\n                    \"{slice_id}: dim-split k-chunk {kc} produced {} outputs, expected n_dim={n_dim}\",\n                    out.len()\n                )));\n            }\n            for (acc, v) in row_accum.iter_mut().zip(out.iter().copied()) {\n                *acc += v;\n    
        }\n        }\n\n        let row_arr = ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[1, n_dim]), row_accum)\n            .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: row output: {e}\")))?;\n        row_outputs.push(row_arr);\n    }\n\n    let stacked: ndarray::ArrayD<f64> = if row_outputs.is_empty() {\n        ndarray::ArrayD::zeros(ndarray::IxDyn(&[0, n_dim]))\n    } else {\n        ndarray::concatenate(\n            ndarray::Axis(0),\n            &row_outputs.iter().map(|a| a.view()).collect::<Vec<_>>(),\n        )\n        .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: row concat: {e}\")))?\n    };\n\n    let output_shape_vec = resolve_output_shape(slice_id, &input_shape, n_dim, target_shape)?;\n\n    let final_result = stacked\n        .as_standard_layout()\n        .into_owned()\n        .into_shape_with_order(ndarray::IxDyn(&output_shape_vec))\n        .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: dim-split reshape: {e}\")))?;\n\n    tracing::info!(\n        slice = %slice_id,\n        rows = total_rows,\n        k_chunks = k_chunks,\n        \"executed dim-split (sequence + K tiled)\"\n    );\n\n    Ok(final_result)\n}\n\nfn execute_generic_dim_split(\n    slice_id: &str,\n    ds: &crate::schema::tiling::DimSplitInfo,\n    target_shape: Option<&[i64]>,\n    tensor_cache: &TensorStore,\n    tmpl_path: &Path,\n) -> Result<ndarray::ArrayD<f64>> {\n    use ndarray::Axis;\n\n    let concat_axis = ds.concat_axis;\n    let split_dim = ds.split_dim;\n    let epg = ds.elements_per_group;\n\n    let tmpl_model = crate::slicer::onnx_proto::load_model(tmpl_path)?;\n    let tmpl_graph = tmpl_model\n        .graph\n        .as_ref()\n        .ok_or_else(|| DsperseError::Pipeline(format!(\"{slice_id}: template has no graph\")))?;\n    let tmpl_init_names: std::collections::HashSet<&str> = tmpl_graph\n        .initializer\n        .iter()\n        .map(|i| i.name.as_str())\n        .collect();\n    let input_names: Vec<String> = 
tmpl_graph\n        .input\n        .iter()\n        .filter(|vi| !tmpl_init_names.contains(vi.name.as_str()))\n        .map(|vi| vi.name.clone())\n        .collect();\n\n    let tmp_dir = tempfile::tempdir()\n        .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: tmpdir: {e}\")))?;\n    let tmpl_on_disk = tmp_dir.path().join(\"dim_tmpl.onnx\");\n    crate::slicer::onnx_proto::save_model(&tmpl_model, &tmpl_on_disk)?;\n\n    let mut group_outputs: Vec<ndarray::ArrayD<f64>> = Vec::new();\n\n    for g in 0..ds.num_groups {\n        let dim_start = g * epg;\n        if dim_start >= ds.dim_size {\n            break;\n        }\n        let dim_end = ((g + 1) * epg).min(ds.dim_size);\n        let actual_size = dim_end - dim_start;\n\n        // dim_size is required to be an exact multiple of epg by the\n        // detector (`smallest_divisor_at_least`), so every group is\n        // exactly `epg` wide and we can feed the sliced view straight\n        // in -- no zero-padding, no output trimming, no risk of\n        // contaminating reductions on non-split axes.\n        debug_assert_eq!(\n            actual_size, epg,\n            \"dim-split detector must enforce dim_size % epg == 0\"\n        );\n        let mut group_cache = TensorStore::new();\n        for vi_name in &input_names {\n            let arr = tensor_cache.try_get(vi_name).ok_or_else(|| {\n                DsperseError::Pipeline(format!(\n                    \"{slice_id}: template input {vi_name:?} not found in tensor cache\"\n                ))\n            })?;\n            let shape = arr.shape();\n            if split_dim < shape.len() && shape[split_dim] == ds.dim_size {\n                let sliced = arr\n                    .slice_axis(Axis(split_dim), ndarray::Slice::from(dim_start..dim_end))\n                    .to_owned();\n                group_cache.put(vi_name.clone(), sliced);\n            } else {\n                group_cache.put(vi_name.clone(), arr.clone());\n            }\n      
  }\n\n        let mut named = run_onnx_inference_multi_named(&tmpl_on_disk, &group_cache, &input_names)?;\n        let (data, shape) = named.remove(&ds.output_name).ok_or_else(|| {\n            DsperseError::Pipeline(format!(\n                \"{slice_id}: dim-split group {g} missing output {:?} (available: {:?})\",\n                ds.output_name,\n                named.keys().collect::<Vec<_>>()\n            ))\n        })?;\n        let group_output = ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shape), data)\n            .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: group {g} reshape: {e}\")))?;\n\n        // Output is naturally `epg` wide along concat_axis.\n        let trimmed = group_output;\n\n        group_outputs.push(trimmed);\n    }\n\n    let result = ndarray::concatenate(\n        Axis(concat_axis),\n        &group_outputs.iter().map(|a| a.view()).collect::<Vec<_>>(),\n    )\n    .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: dim-split concat: {e}\")))?;\n\n    let output_shape_vec = if let Some(target) = target_shape {\n        target\n            .iter()\n            .map(|&d| {\n                usize::try_from(d).map_err(|_| {\n                    DsperseError::Pipeline(format!(\n                        \"{slice_id}: invalid target dim {d} in dim-split reshape\"\n                    ))\n                })\n            })\n            .collect::<Result<Vec<_>>>()?\n    } else {\n        result.shape().to_vec()\n    };\n\n    let final_result = result\n        .as_standard_layout()\n        .into_owned()\n        .into_shape_with_order(ndarray::IxDyn(&output_shape_vec))\n        .map_err(|e| DsperseError::Pipeline(format!(\"{slice_id}: dim-split reshape: {e}\")))?;\n\n    tracing::info!(\n        slice = %slice_id,\n        groups = ds.num_groups,\n        split_kind = ?ds.split_kind,\n        \"executed dim-split (generic)\"\n    );\n\n    Ok(final_result)\n}\n\nfn resolve_output_shape(\n    slice_id: &str,\n    
input_shape: &[usize],\n    n_dim: usize,\n    target_shape: Option<&[i64]>,\n) -> Result<Vec<usize>> {\n    if let Some(target) = target_shape {\n        target\n            .iter()\n            .map(|&d| {\n                usize::try_from(d).map_err(|_| {\n                    DsperseError::Pipeline(format!(\n                        \"{slice_id}: invalid target dimension {d} in dim-split reshape\"\n                    ))\n                })\n            })\n            .collect::<Result<Vec<_>>>()\n    } else {\n        let mut s = input_shape.to_vec();\n        if let Some(last) = s.last_mut() {\n            *last = n_dim;\n        }\n        Ok(s)\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/incremental.rs",
    "content": "use std::path::{Path, PathBuf};\n\nuse ndarray::ArrayD;\n\nuse crate::error::{DsperseError, Result};\nuse crate::schema::execution::{ExecutionChain, ExecutionInfo, ExecutionResultEntry, RunMetadata};\nuse crate::schema::metadata::{BackendKind, ModelMetadata, RunSliceMetadata};\nuse crate::schema::tiling::{ChannelSplitInfo, TilingInfo};\n\nuse super::runner::{build_execution_chain, build_run_metadata, load_model_metadata};\nuse super::strategy::ExecutionStrategy;\nuse super::tensor_store::TensorStore;\n\npub struct SliceWork {\n    pub slice_id: String,\n    pub input: ArrayD<f64>,\n    pub named_inputs: Vec<(String, ArrayD<f64>)>,\n    pub backend: BackendKind,\n    pub use_circuit: bool,\n    pub tiling: Option<TilingInfo>,\n    pub channel_split: Option<ChannelSplitInfo>,\n    pub circuit_path: Option<String>,\n    pub onnx_path: Option<String>,\n    pub slice_meta: RunSliceMetadata,\n}\n\npub struct SliceExecutionResult {\n    pub slice_id: String,\n    pub output: ArrayD<f64>,\n    pub execution_info: ExecutionInfo,\n}\n\npub struct IncrementalRun {\n    tensor_cache: TensorStore,\n    execution_chain: ExecutionChain,\n    model_meta: ModelMetadata,\n    run_meta: RunMetadata,\n    slices_dir: PathBuf,\n    current_slice: Option<String>,\n    results: Vec<ExecutionResultEntry>,\n}\n\nimpl IncrementalRun {\n    pub fn new(slices_dir: &Path, input: ArrayD<f64>) -> Result<Self> {\n        let model_meta = load_model_metadata(slices_dir)?;\n\n        let chain = build_execution_chain(&model_meta, slices_dir)?;\n        let run_meta = build_run_metadata(&model_meta, slices_dir, &chain)?;\n\n        let first_slice = model_meta\n            .slices\n            .first()\n            .ok_or_else(|| DsperseError::Pipeline(\"model has no slices\".into()))?;\n        let filtered = &first_slice.dependencies.filtered_inputs;\n        if filtered.len() != 1 {\n            return Err(DsperseError::Pipeline(format!(\n                \"multi-input models not 
supported: first slice declares {} filtered inputs\",\n                filtered.len()\n            )));\n        }\n        let input_name = filtered[0].clone();\n        let mut tensor_cache = TensorStore::new();\n        tensor_cache.put(input_name, input);\n\n        let current_slice = chain.head.clone();\n\n        Ok(Self {\n            tensor_cache,\n            execution_chain: chain,\n            model_meta,\n            run_meta,\n            slices_dir: slices_dir.to_path_buf(),\n            current_slice,\n            results: Vec::new(),\n        })\n    }\n\n    pub fn next_slice(&self) -> Result<Option<SliceWork>> {\n        let slice_id = match self.current_slice.as_ref() {\n            Some(id) => id,\n            None => return Ok(None),\n        };\n        let node = self.execution_chain.nodes.get(slice_id).ok_or_else(|| {\n            DsperseError::Pipeline(format!(\"execution chain missing node for {slice_id}\"))\n        })?;\n        let meta = self.run_meta.slices.get(slice_id).ok_or_else(|| {\n            DsperseError::Pipeline(format!(\"run metadata missing slice {slice_id}\"))\n        })?;\n\n        let strategy = ExecutionStrategy::from_metadata(meta, node.use_circuit)?;\n        let (input, named_inputs) = match strategy {\n            ExecutionStrategy::ChannelSplit(cs) => {\n                let t = self.tensor_cache.get(&cs.input_name)?.clone();\n                (t, Vec::new())\n            }\n            ExecutionStrategy::DimSplit(ds) => {\n                let t = self.tensor_cache.get(&ds.input_name)?.clone();\n                (t, Vec::new())\n            }\n            ExecutionStrategy::Tiled(tiling) => {\n                let t = self.tensor_cache.get(&tiling.input_name)?.clone();\n                (t, Vec::new())\n            }\n            ExecutionStrategy::Single { .. 
} => {\n                let filtered = &meta.dependencies.filtered_inputs;\n                let mut named = Vec::with_capacity(filtered.len());\n                for name in filtered {\n                    let arr = self.tensor_cache.get(name)?;\n                    named.push((name.clone(), arr.clone()));\n                }\n                let concatenated = self.tensor_cache.gather(filtered)?;\n                (concatenated, named)\n            }\n        };\n\n        Ok(Some(SliceWork {\n            slice_id: slice_id.clone(),\n            input,\n            named_inputs,\n            backend: node.backend,\n            use_circuit: node.use_circuit,\n            tiling: meta.tiling.clone(),\n            channel_split: meta.channel_split.clone(),\n            circuit_path: node.circuit_path.clone(),\n            onnx_path: node.onnx_path.clone(),\n            slice_meta: meta.clone(),\n        }))\n    }\n\n    pub fn apply_result(&mut self, result: SliceExecutionResult) -> Result<()> {\n        let slice_id = &result.slice_id;\n\n        match self.current_slice.as_deref() {\n            Some(expected) if expected != slice_id => {\n                return Err(DsperseError::Pipeline(format!(\n                    \"out-of-order result: expected {expected}, got {slice_id}\"\n                )));\n            }\n            None => {\n                return Err(DsperseError::Pipeline(format!(\n                    \"pipeline already complete, unexpected result for {slice_id}\"\n                )));\n            }\n            _ => {}\n        }\n\n        let meta = self\n            .run_meta\n            .slices\n            .get(slice_id)\n            .ok_or_else(|| DsperseError::Pipeline(format!(\"unknown slice {slice_id}\")))?;\n\n        let strategy = ExecutionStrategy::from_metadata(meta, false)?;\n        match strategy {\n            ExecutionStrategy::ChannelSplit(cs) => {\n                self.tensor_cache.put(cs.output_name.clone(), result.output);\n   
         }\n            ExecutionStrategy::DimSplit(ds) => {\n                self.tensor_cache.put(ds.output_name.clone(), result.output);\n            }\n            ExecutionStrategy::Tiled(tiling) => {\n                self.tensor_cache\n                    .put(tiling.output_name.clone(), result.output);\n            }\n            ExecutionStrategy::Single { .. } => {\n                if meta.dependencies.output.is_empty() {\n                    return Err(DsperseError::Pipeline(format!(\n                        \"slice {slice_id} has no output dependency names\"\n                    )));\n                }\n                for name in &meta.dependencies.output {\n                    self.tensor_cache.put(name.clone(), result.output.clone());\n                }\n            }\n        }\n\n        self.results.push(ExecutionResultEntry {\n            slice_id: slice_id.clone(),\n            witness_execution: Some(result.execution_info),\n            proof_execution: None,\n            verification_execution: None,\n        });\n\n        let next = self\n            .execution_chain\n            .nodes\n            .get(slice_id)\n            .and_then(|n| n.next.clone());\n        self.current_slice = next;\n\n        Ok(())\n    }\n\n    pub fn is_complete(&self) -> bool {\n        self.current_slice.is_none()\n    }\n\n    pub fn final_output(&self) -> Option<&ArrayD<f64>> {\n        let last_slice = self.model_meta.slices.last()?;\n        let slice_id = format!(\"slice_{}\", last_slice.index);\n        let meta = self.run_meta.slices.get(&slice_id)?;\n\n        let strategy = ExecutionStrategy::from_metadata(meta, false).ok()?;\n        match strategy.output_name() {\n            Some(name) => self.tensor_cache.try_get(name),\n            None => {\n                let output_name = meta.dependencies.output.first()?;\n                self.tensor_cache.try_get(output_name)\n            }\n        }\n    }\n\n    pub fn into_run_metadata(self) -> 
RunMetadata {\n        let mut meta = self.run_meta;\n        meta.execution_chain.execution_results = self.results;\n        meta.source_path = Some(self.slices_dir.to_string_lossy().into_owned());\n        meta\n    }\n\n    pub fn slices_dir(&self) -> &Path {\n        &self.slices_dir\n    }\n\n    pub fn model_meta(&self) -> &ModelMetadata {\n        &self.model_meta\n    }\n\n    pub fn run_meta(&self) -> &RunMetadata {\n        &self.run_meta\n    }\n\n    pub fn tensor_cache(&self) -> &TensorStore {\n        &self.tensor_cache\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/mod.rs",
    "content": "mod channel_split;\nmod combined;\nmod compiler;\nmod dim_split;\nmod incremental;\npub mod packager;\nmod prover;\npub mod publisher;\npub mod runner;\npub mod slice_cache;\nmod stage;\npub mod strategy;\npub mod tensor_store;\npub mod tile_executor;\nmod tiled;\nmod verifier;\n\npub use combined::CombinedRun;\npub use compiler::{\n    CompileReport, HolographicSetupReport, SliceAnalysisReport, analyze_slices, compile_slices,\n    setup_holographic_for_slices,\n};\npub use incremental::{IncrementalRun, SliceExecutionResult, SliceWork};\npub use prover::prove_run;\npub use runner::{RunConfig, extract_onnx_initializers, run_inference};\npub use slice_cache::SliceAssets;\npub use strategy::ExecutionStrategy;\npub use tensor_store::TensorStore;\npub use tiled::{reconstruct_from_tiles, split_for_tiling, split_into_tiles};\npub use verifier::verify_run;\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/packager.rs",
    "content": "use std::collections::HashSet;\nuse std::fs;\nuse std::io::Read;\nuse std::path::{Path, PathBuf};\n\nuse serde::Serialize;\nuse sha2::{Digest, Sha256};\nuse walkdir::WalkDir;\n\nuse crate::error::{DsperseError, Result};\nuse crate::pipeline::compiler::compute_bundle_signature;\nuse crate::pipeline::runner::load_model_metadata;\nuse crate::schema::metadata::SliceMetadata;\nuse crate::utils::paths::resolve_relative_path;\n\npub struct PackageConfig {\n    pub output_dir: PathBuf,\n    pub author: Option<String>,\n    pub model_version: Option<String>,\n    pub model_name: Option<String>,\n    pub timeout: Option<u64>,\n    pub curve: Option<String>,\n}\n\n#[derive(Debug)]\npub struct PackageResult {\n    pub component_count: usize,\n    pub wb_count: usize,\n    pub manifest_path: PathBuf,\n    pub total_size: u64,\n}\n\n#[derive(Serialize)]\nstruct ArtifactRef {\n    sha256: String,\n    role: String,\n    filename: String,\n    size_bytes: u64,\n}\n\n#[derive(Serialize)]\nstruct Manifest {\n    version: u32,\n    model: ModelInfo,\n    #[serde(default, skip_serializing_if = \"Vec::is_empty\")]\n    artifacts: Vec<ArtifactRef>,\n    components: Vec<ComponentEntry>,\n    dag: Vec<DagNode>,\n}\n\n#[derive(Serialize)]\nstruct ModelInfo {\n    name: String,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    curve: Option<String>,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    author: Option<String>,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    version: Option<String>,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    timeout: Option<u64>,\n    input_schema: InputSchema,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    dsperse_version: Option<String>,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    jstprove_version: Option<String>,\n}\n\n#[derive(Serialize)]\nstruct InputSchema {\n    shape: Vec<Vec<i64>>,\n    output_shapes: Vec<Vec<i64>>,\n    output_names: 
Vec<String>,\n}\n\n#[derive(Serialize)]\nstruct ComponentEntry {\n    index: usize,\n    name: String,\n    sha256: String,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    curve: Option<String>,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    proof_system: Option<String>,\n    files: Vec<String>,\n    weights: Vec<WeightRef>,\n}\n\n#[derive(Serialize)]\nstruct WeightRef {\n    sha256: String,\n    role: String,\n    filename: String,\n    size_bytes: u64,\n}\n\n#[derive(Serialize)]\nstruct DagNode {\n    component_index: usize,\n    inputs: Vec<String>,\n    outputs: Vec<String>,\n    input_shape: Vec<Vec<i64>>,\n    output_shape: Vec<Vec<i64>>,\n}\n\nconst VALID_CURVES: &[&str] = &[\n    \"bn254\",\n    \"goldilocks\",\n    \"goldilocks_basefold\",\n    \"goldilocks_ext2\",\n    \"goldilocks_whir\",\n    \"goldilocks_whir_pq\",\n];\n\nfn normalize_curve(curve: Option<&str>) -> Result<Option<String>> {\n    let Some(c) = curve else { return Ok(None) };\n    let c = c.trim().to_ascii_lowercase();\n    if c.is_empty() {\n        return Err(DsperseError::Other(\"curve must not be empty\".into()));\n    }\n    if !VALID_CURVES.contains(&c.as_str()) {\n        return Err(DsperseError::Other(format!(\n            \"unsupported curve {c:?}; expected one of: {}\",\n            VALID_CURVES.join(\", \")\n        )));\n    }\n    Ok(Some(c))\n}\n\npub fn package_content_addressed(\n    slices_dir: &Path,\n    config: &PackageConfig,\n) -> Result<PackageResult> {\n    if !slices_dir.is_dir() {\n        return Err(DsperseError::Other(format!(\n            \"slices directory not found: {}\",\n            slices_dir.display()\n        )));\n    }\n\n    let curve = normalize_curve(config.curve.as_deref())?;\n\n    let model_meta = load_model_metadata(slices_dir)?;\n\n    let components_dir = config.output_dir.join(\"components\");\n    let wb_dir = config.output_dir.join(\"wb\");\n    fs::create_dir_all(&components_dir).map_err(|e| DsperseError::io(e, 
&components_dir))?;\n    fs::create_dir_all(&wb_dir).map_err(|e| DsperseError::io(e, &wb_dir))?;\n\n    let mut components: Vec<ComponentEntry> = Vec::new();\n    let mut dag_nodes: Vec<DagNode> = Vec::new();\n    let mut written_components: HashSet<String> = HashSet::new();\n    let mut written_wbs: HashSet<String> = HashSet::new();\n    let mut total_size: u64 = 0;\n\n    for slice in &model_meta.slices {\n        let slice_dir = slices_dir.join(format!(\"slice_{}\", slice.index));\n\n        let (component_hash, component_files, proof_system, source) =\n            extract_component(slices_dir, slice, &slice_dir, curve.as_deref())?;\n\n        if !written_components.contains(&component_hash) {\n            let dest = components_dir.join(&component_hash);\n            fs::create_dir_all(&dest).map_err(|e| DsperseError::io(e, &dest))?;\n\n            match &source {\n                ComponentSource::CircuitBundle(circuit_dir) => {\n                    total_size += copy_files_flat(circuit_dir, &dest)?;\n                }\n                ComponentSource::OnnxFile(onnx_path) => {\n                    if let Some(filename) = component_files.first() {\n                        let dest_file = dest.join(filename);\n                        fs::copy(onnx_path, &dest_file)\n                            .map_err(|e| DsperseError::io(e, onnx_path))?;\n                        total_size += onnx_path\n                            .metadata()\n                            .map_err(|e| DsperseError::io(e, onnx_path))?\n                            .len();\n                    }\n                }\n            }\n            written_components.insert(component_hash.clone());\n        }\n\n        let mut weights: Vec<WeightRef> = Vec::new();\n        let payload_blobs = collect_payload_blobs(slices_dir, slice, &slice_dir)?;\n        for (role, filename, data) in &payload_blobs {\n            let hash = sha256_bytes(data);\n            if !written_wbs.contains(&hash) {\n              
  let wb_path = wb_dir.join(&hash);\n                fs::write(&wb_path, data).map_err(|e| DsperseError::io(e, &wb_path))?;\n                total_size += data.len() as u64;\n                written_wbs.insert(hash.clone());\n            }\n            weights.push(WeightRef {\n                sha256: hash,\n                role: role.clone(),\n                filename: filename.clone(),\n                size_bytes: data.len() as u64,\n            });\n        }\n\n        components.push(ComponentEntry {\n            index: slice.index,\n            name: format!(\"slice_{}\", slice.index),\n            sha256: component_hash,\n            curve: curve.clone(),\n            proof_system,\n            files: component_files,\n            weights,\n        });\n\n        dag_nodes.push(DagNode {\n            component_index: slice.index,\n            inputs: slice.dependencies.input.clone(),\n            outputs: slice.dependencies.output.clone(),\n            input_shape: slice.shape.tensor_shape.input.clone(),\n            output_shape: slice.shape.tensor_shape.output.clone(),\n        });\n\n        if (slice.index + 1) % 50 == 0 {\n            tracing::info!(\n                progress = slice.index + 1,\n                total = model_meta.slices.len(),\n                \"packaging slices\"\n            );\n        }\n    }\n\n    let mut artifacts: Vec<ArtifactRef> = Vec::new();\n    let model_artifact_files = [\"metadata.msgpack\", \"model.onnx\"];\n    for filename in &model_artifact_files {\n        let src = slices_dir.join(filename);\n        if !src.is_file() {\n            return Err(DsperseError::Other(format!(\n                \"required model artifact '{}' not found at {}\",\n                filename,\n                src.display()\n            )));\n        }\n        reject_symlink_path(&src)?;\n        let data = fs::read(&src).map_err(|e| DsperseError::io(e, &src))?;\n        let hash = sha256_bytes(&data);\n        if !written_wbs.contains(&hash) 
{\n            let wb_path = wb_dir.join(&hash);\n            fs::write(&wb_path, &data).map_err(|e| DsperseError::io(e, &wb_path))?;\n            total_size += data.len() as u64;\n            written_wbs.insert(hash.clone());\n        }\n        artifacts.push(ArtifactRef {\n            sha256: hash,\n            role: \"artifact\".to_string(),\n            filename: (*filename).to_string(),\n            size_bytes: data.len() as u64,\n        });\n        tracing::info!(filename, \"packaged model artifact\");\n    }\n\n    let model_name = config\n        .model_name\n        .clone()\n        .or_else(|| {\n            slices_dir\n                .parent()\n                .and_then(|p| p.file_name())\n                .and_then(|n| n.to_str())\n                .map(String::from)\n        })\n        .unwrap_or_else(|| \"unknown\".to_string());\n\n    let manifest = Manifest {\n        version: 1,\n        model: ModelInfo {\n            name: model_name,\n            curve: curve.clone(),\n            author: config.author.clone(),\n            version: config.model_version.clone(),\n            timeout: config.timeout,\n            input_schema: InputSchema {\n                shape: model_meta.input_shape,\n                output_shapes: model_meta.output_shapes,\n                output_names: model_meta.output_names,\n            },\n            dsperse_version: model_meta.dsperse_version,\n            jstprove_version: model_meta.jstprove_version,\n        },\n        artifacts,\n        components,\n        dag: dag_nodes,\n    };\n\n    let manifest_path = config.output_dir.join(\"manifest.msgpack\");\n    let manifest_bytes = rmp_serde::to_vec_named(&manifest)\n        .map_err(|e| DsperseError::Other(format!(\"failed to serialize manifest: {e}\")))?;\n    fs::write(&manifest_path, &manifest_bytes).map_err(|e| DsperseError::io(e, &manifest_path))?;\n    total_size += manifest_bytes.len() as u64;\n\n    Ok(PackageResult {\n        component_count: 
written_components.len(),\n        wb_count: written_wbs.len(),\n        manifest_path,\n        total_size,\n    })\n}\n\nfn resolve_circuit_dir(slices_dir: &Path, slice: &SliceMetadata) -> Result<Option<PathBuf>> {\n    let bundle = slices_dir\n        .join(format!(\"slice_{}\", slice.index))\n        .join(\"jstprove\")\n        .join(\"circuit.bundle\");\n    if bundle.is_dir() {\n        return Ok(Some(bundle));\n    }\n    if let Some(ref cs) = slice.channel_split\n        && let Some(group) = cs.groups.first()\n        && let Some(ref circuit_path) = group.jstprove_circuit_path\n    {\n        let abs = resolve_relative_path(slices_dir, circuit_path)?;\n        if abs.is_dir() {\n            return Ok(Some(abs));\n        }\n    }\n    if let Some(ref ds) = slice.dim_split\n        && let Some(ref circuit_path) = ds.jstprove_circuit_path\n    {\n        let abs = resolve_relative_path(slices_dir, circuit_path)?;\n        if abs.is_dir() {\n            return Ok(Some(abs));\n        }\n    }\n    Ok(None)\n}\n\nenum ComponentSource {\n    CircuitBundle(PathBuf),\n    OnnxFile(PathBuf),\n}\n\nfn resolve_source_onnx(slices_dir: &Path, slice: &SliceMetadata) -> Result<PathBuf> {\n    if let Some(ref cs) = slice.channel_split\n        && let Some(group) = cs.groups.first()\n    {\n        let p = resolve_relative_path(slices_dir, &group.path)?;\n        reject_symlink_path(&p)?;\n        if !p.is_file() {\n            return Err(DsperseError::Other(format!(\n                \"slice {} channel group ONNX configured but missing: {}\",\n                slice.index,\n                p.display()\n            )));\n        }\n        return Ok(p);\n    }\n    if let Some(ref ds) = slice.dim_split\n        && let Some(ref tmpl) = ds.template_path\n    {\n        let p = resolve_relative_path(slices_dir, tmpl)?;\n        reject_symlink_path(&p)?;\n        if !p.is_file() {\n            return Err(DsperseError::Other(format!(\n                \"slice {} dim-split 
template configured but missing: {}\",\n                slice.index,\n                p.display()\n            )));\n        }\n        return Ok(p);\n    }\n    if let Some(ref tiling) = slice.tiling\n        && let Some(ref tile) = tiling.tile\n    {\n        let p = resolve_relative_path(slices_dir, &tile.path)?;\n        reject_symlink_path(&p)?;\n        if !p.is_file() {\n            return Err(DsperseError::Other(format!(\n                \"slice {} tile ONNX configured but missing: {}\",\n                slice.index,\n                p.display()\n            )));\n        }\n        return Ok(p);\n    }\n    let p = slice.resolve_onnx(slices_dir)?;\n    reject_symlink_path(&p)?;\n    Ok(p)\n}\n\nfn list_bundle_files(dir: &Path) -> Result<Vec<String>> {\n    let mut files = Vec::new();\n    for entry in WalkDir::new(dir) {\n        let entry = entry.map_err(|e| DsperseError::Other(e.to_string()))?;\n        reject_symlink(&entry)?;\n        if entry.file_type().is_file() {\n            let relative = entry\n                .path()\n                .strip_prefix(dir)\n                .map_err(|e| DsperseError::Other(e.to_string()))?\n                .components()\n                .map(|c| match c {\n                    std::path::Component::Normal(part) => Ok(part.to_string_lossy().into_owned()),\n                    _ => Err(DsperseError::Other(\n                        \"unexpected non-normal path component in bundle\".into(),\n                    )),\n                })\n                .collect::<Result<Vec<_>>>()?\n                .join(\"/\");\n            files.push(relative);\n        }\n    }\n    files.sort();\n    Ok(files)\n}\n\nfn extract_component(\n    slices_dir: &Path,\n    slice: &SliceMetadata,\n    _slice_dir: &Path,\n    curve: Option<&str>,\n) -> Result<(String, Vec<String>, Option<String>, ComponentSource)> {\n    if let Some(dir) = resolve_circuit_dir(slices_dir, slice)? 
{\n        let onnx_path = resolve_source_onnx(slices_dir, slice)?;\n        let sig = compute_bundle_signature(&onnx_path, curve, &dir)?;\n        let files = list_bundle_files(&dir)?;\n        return Ok((\n            sig,\n            files,\n            Some(\"jstprove\".to_string()),\n            ComponentSource::CircuitBundle(dir),\n        ));\n    }\n\n    let onnx_path = slice.resolve_onnx(slices_dir)?;\n    reject_symlink_path(&onnx_path)?;\n    if onnx_path.is_file() {\n        let filename = onnx_path\n            .file_name()\n            .and_then(|n| n.to_str())\n            .unwrap_or(\"model.onnx\")\n            .to_string();\n        let hash = hash_named_file(&onnx_path, &filename, curve)?;\n        return Ok((\n            hash,\n            vec![filename],\n            None,\n            ComponentSource::OnnxFile(onnx_path),\n        ));\n    }\n\n    Err(DsperseError::Other(format!(\n        \"slice {} has no circuit directory or ONNX artifact to package\",\n        slice.index\n    )))\n}\n\nfn collect_payload_blobs(\n    slices_dir: &Path,\n    slice: &SliceMetadata,\n    slice_dir: &Path,\n) -> Result<Vec<(String, String, Vec<u8>)>> {\n    let mut blobs: Vec<(String, String, Vec<u8>)> = Vec::new();\n\n    let onnx_path = slice.resolve_onnx(slices_dir).unwrap_or_else(|_| {\n        slice_dir\n            .join(\"payload\")\n            .join(format!(\"slice_{}.onnx\", slice.index))\n    });\n    reject_symlink_path(&onnx_path)?;\n    if onnx_path.is_file() {\n        let data = fs::read(&onnx_path).map_err(|e| DsperseError::io(e, &onnx_path))?;\n        let filename = onnx_path\n            .file_name()\n            .and_then(|n| n.to_str())\n            .unwrap_or(\"model.onnx\")\n            .to_string();\n        blobs.push((\"payload\".to_string(), filename, data));\n    }\n\n    if let Some(ref cs) = slice.channel_split {\n        for group in &cs.groups {\n            let group_path = resolve_relative_path(slices_dir, &group.path)?;\n  
          reject_symlink_path(&group_path)?;\n            if group_path.is_file() {\n                let data = fs::read(&group_path).map_err(|e| DsperseError::io(e, &group_path))?;\n                let filename = group_path\n                    .file_name()\n                    .and_then(|n| n.to_str())\n                    .unwrap_or(\"group.onnx\")\n                    .to_string();\n                blobs.push((\"channel_group\".to_string(), filename, data));\n            }\n        }\n        if let Some(ref bias_path) = cs.bias_path {\n            let abs = resolve_relative_path(slices_dir, bias_path)?;\n            reject_symlink_path(&abs)?;\n            if abs.is_file() {\n                let data = fs::read(&abs).map_err(|e| DsperseError::io(e, &abs))?;\n                blobs.push((\"bias\".to_string(), \"bias.msgpack\".to_string(), data));\n            }\n        }\n    }\n\n    Ok(blobs)\n}\n\nfn reject_symlink_path(path: &Path) -> Result<()> {\n    if path\n        .symlink_metadata()\n        .is_ok_and(|m| m.file_type().is_symlink())\n    {\n        return Err(DsperseError::Other(format!(\n            \"symlinked file is not allowed: {}\",\n            path.display()\n        )));\n    }\n    Ok(())\n}\n\nfn reject_symlink(entry: &walkdir::DirEntry) -> Result<()> {\n    if entry.file_type().is_symlink() {\n        return Err(DsperseError::Other(format!(\n            \"symlinked bundle entry is not allowed: {}\",\n            entry.path().display()\n        )));\n    }\n    Ok(())\n}\n\nfn hash_named_file(path: &Path, filename: &str, curve: Option<&str>) -> Result<String> {\n    let mut hasher = Sha256::new();\n    if let Some(c) = curve {\n        let c_bytes = c.as_bytes();\n        hasher.update((c_bytes.len() as u64).to_le_bytes());\n        hasher.update(c_bytes);\n    }\n    let name_bytes = filename.as_bytes();\n    hasher.update((name_bytes.len() as u64).to_le_bytes());\n    hasher.update(name_bytes);\n    let mut file = 
fs::File::open(path).map_err(|e| DsperseError::io(e, path))?;\n    let file_len = file\n        .metadata()\n        .map_err(|e| DsperseError::io(e, path))?\n        .len();\n    hasher.update(file_len.to_le_bytes());\n    let mut buf = [0u8; 8192];\n    loop {\n        let n = file.read(&mut buf).map_err(|e| DsperseError::io(e, path))?;\n        if n == 0 {\n            break;\n        }\n        hasher.update(&buf[..n]);\n    }\n    Ok(encode_hex(&hasher.finalize()))\n}\n\nfn sha256_bytes(data: &[u8]) -> String {\n    let mut hasher = Sha256::new();\n    hasher.update(data);\n    encode_hex(&hasher.finalize())\n}\n\nfn encode_hex(bytes: &[u8]) -> String {\n    let mut s = String::with_capacity(bytes.len() * 2);\n    for b in bytes {\n        use std::fmt::Write;\n        write!(s, \"{:02x}\", b).unwrap();\n    }\n    s\n}\n\nfn copy_files_flat(source_dir: &Path, dest_dir: &Path) -> Result<u64> {\n    let mut total: u64 = 0;\n    for entry in WalkDir::new(source_dir) {\n        let entry = entry.map_err(|e| DsperseError::Other(e.to_string()))?;\n        reject_symlink(&entry)?;\n        if entry.file_type().is_file() {\n            let relative = entry\n                .path()\n                .strip_prefix(source_dir)\n                .map_err(|e| DsperseError::Other(e.to_string()))?;\n            let dest_path = dest_dir.join(relative);\n            if let Some(parent) = dest_path.parent() {\n                fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?;\n            }\n            fs::copy(entry.path(), &dest_path).map_err(|e| DsperseError::io(e, entry.path()))?;\n            total += entry\n                .path()\n                .metadata()\n                .map_err(|e| DsperseError::io(e, entry.path()))?\n                .len();\n        }\n    }\n    Ok(total)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use std::fs;\n    use tempfile::TempDir;\n\n    use crate::schema::metadata::{\n        Compilation, Dependencies, 
ModelMetadata, SliceShapeWrapper, TensorShape,\n    };\n    use crate::slicer::onnx_proto;\n\n    fn write_minimal_onnx(path: &Path, input_dim: i64) {\n        let node = onnx_proto::NodeProto {\n            op_type: \"Relu\".to_string(),\n            input: vec![\"x\".to_string()],\n            output: vec![\"y\".to_string()],\n            ..Default::default()\n        };\n        let graph = onnx_proto::make_graph(\n            \"g\",\n            vec![node],\n            vec![onnx_proto::make_tensor_value_info(\"x\", 1, &[1, input_dim])],\n            vec![onnx_proto::make_tensor_value_info(\"y\", 1, &[1, input_dim])],\n            vec![],\n        );\n        let model = onnx_proto::make_model(graph, 13);\n        onnx_proto::save_model(&model, path).unwrap();\n    }\n\n    fn create_test_model_metadata(slices_dir: &Path, count: usize) {\n        let mut slices = Vec::new();\n        for i in 0..count {\n            let slice_dir = slices_dir.join(format!(\"slice_{}\", i));\n            let payload_dir = slice_dir.join(\"payload\");\n            fs::create_dir_all(&payload_dir).unwrap();\n            write_minimal_onnx(\n                &payload_dir.join(format!(\"slice_{}.onnx\", i)),\n                (64 + i) as i64,\n            );\n\n            let circuit_dir = slice_dir.join(\"jstprove\").join(\"circuit.bundle\");\n            fs::create_dir_all(&circuit_dir).unwrap();\n            fs::write(circuit_dir.join(\"circuit.bin\"), format!(\"circuit_{}\", i)).unwrap();\n            fs::write(\n                circuit_dir.join(\"settings.json\"),\n                format!(\"{{\\\"idx\\\":{}}}\", i),\n            )\n            .unwrap();\n\n            let inputs = if i == 0 {\n                vec![\"model_input\".to_string()]\n            } else {\n                vec![format!(\"tensor_{}\", i - 1)]\n            };\n            let outputs = vec![format!(\"tensor_{}\", i)];\n\n            slices.push(SliceMetadata {\n                index: i,\n                
filename: format!(\"slice_{}.onnx\", i),\n                path: slice_dir.to_string_lossy().to_string(),\n                relative_path: format!(\"slice_{}/payload/slice_{}.onnx\", i, i),\n                shape: SliceShapeWrapper {\n                    tensor_shape: TensorShape {\n                        input: vec![vec![1, 3, 224, 224]],\n                        output: vec![vec![1, 64, 112, 112]],\n                    },\n                },\n                dependencies: Dependencies {\n                    input: inputs,\n                    output: outputs,\n                    filtered_inputs: vec![],\n                },\n                tiling: None,\n                channel_split: None,\n                dim_split: None,\n                compilation: Compilation::default(),\n                slice_metadata: None,\n                slice_metadata_relative_path: None,\n            });\n        }\n\n        let meta = ModelMetadata {\n            original_model: \"test_model\".to_string(),\n            model_type: \"onnx\".to_string(),\n            input_shape: vec![vec![1, 3, 224, 224]],\n            output_shapes: vec![vec![1, 1000]],\n            output_names: vec![\"output\".to_string()],\n            slice_points: (0..count).collect(),\n            slices,\n            dsperse_version: Some(\"0.0.1-test\".to_string()),\n            dsperse_rev: None,\n            jstprove_version: Some(\"0.1.0-test\".to_string()),\n            jstprove_rev: None,\n            traced_shapes: None,\n            traced_types: None,\n            original_model_path: None,\n            folded_constant_names: vec![],\n        };\n\n        meta.save(&slices_dir.join(\"metadata.msgpack\")).unwrap();\n        ensure_test_artifacts(slices_dir);\n    }\n\n    fn ensure_test_artifacts(slices_dir: &Path) {\n        let p = slices_dir.join(\"model.onnx\");\n        if !p.exists() {\n            fs::write(&p, b\"fake-onnx-for-test\").unwrap();\n        }\n    }\n\n    #[test]\n    fn 
test_content_addressed_output_structure() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 3);\n\n        let output_dir = tmp.path().join(\"output\");\n        let config = PackageConfig {\n            output_dir: output_dir.clone(),\n\n            author: Some(\"test-author\".to_string()),\n            model_version: Some(\"1.0.0\".to_string()),\n            model_name: Some(\"test-model\".to_string()),\n            timeout: Some(300),\n            curve: None,\n        };\n\n        let result = package_content_addressed(&slices_dir, &config).unwrap();\n\n        assert_eq!(result.component_count, 3);\n        assert_eq!(result.wb_count, 5);\n        assert!(result.total_size > 0);\n        assert!(output_dir.join(\"components\").is_dir());\n        assert!(output_dir.join(\"wb\").is_dir());\n        assert!(output_dir.join(\"manifest.msgpack\").is_file());\n\n        let manifest_bytes = fs::read(output_dir.join(\"manifest.msgpack\")).unwrap();\n        let manifest: serde_json::Value = rmp_serde::from_slice(&manifest_bytes).unwrap();\n        let arts = manifest[\"artifacts\"].as_array().unwrap();\n        assert_eq!(arts.len(), 2);\n        let filenames: Vec<&str> = arts.iter().filter_map(|a| a[\"filename\"].as_str()).collect();\n        assert!(filenames.contains(&\"metadata.msgpack\"));\n        assert!(filenames.contains(&\"model.onnx\"));\n        for art in arts {\n            assert_eq!(art[\"role\"].as_str().unwrap(), \"artifact\");\n            assert!(art[\"sha256\"].as_str().unwrap().len() == 64);\n            assert!(art[\"size_bytes\"].as_u64().unwrap() > 0);\n        }\n    }\n\n    #[test]\n    fn test_missing_model_onnx_fails() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        
fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 1);\n        fs::remove_file(slices_dir.join(\"model.onnx\")).unwrap();\n\n        let output_dir = tmp.path().join(\"output\");\n        let config = PackageConfig {\n            output_dir,\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n        let err = package_content_addressed(&slices_dir, &config).unwrap_err();\n        assert!(err.to_string().contains(\"model.onnx\"));\n    }\n\n    #[cfg(unix)]\n    #[test]\n    fn test_symlinked_artifact_rejected() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 1);\n        fs::remove_file(slices_dir.join(\"model.onnx\")).unwrap();\n        let target = tmp.path().join(\"evil.bin\");\n        fs::write(&target, b\"evil\").unwrap();\n        std::os::unix::fs::symlink(&target, slices_dir.join(\"model.onnx\")).unwrap();\n\n        let output_dir = tmp.path().join(\"output\");\n        let config = PackageConfig {\n            output_dir,\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n        let err = package_content_addressed(&slices_dir, &config).unwrap_err();\n        assert!(err.to_string().contains(\"symlink\"));\n    }\n\n    #[test]\n    fn test_manifest_structure() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 2);\n\n        let output_dir = tmp.path().join(\"output\");\n        let config = PackageConfig {\n            output_dir: output_dir.clone(),\n\n            
author: Some(\"test-author\".to_string()),\n            model_version: Some(\"1.0.0\".to_string()),\n            model_name: Some(\"test-model\".to_string()),\n            timeout: Some(300),\n            curve: None,\n        };\n\n        package_content_addressed(&slices_dir, &config).unwrap();\n\n        let manifest: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(output_dir.join(\"manifest.msgpack\")).unwrap()).unwrap();\n\n        assert_eq!(manifest[\"version\"], 1);\n        assert_eq!(manifest[\"model\"][\"name\"], \"test-model\");\n        assert_eq!(manifest[\"model\"][\"author\"], \"test-author\");\n        assert_eq!(manifest[\"model\"][\"version\"], \"1.0.0\");\n        assert_eq!(manifest[\"model\"][\"timeout\"], 300);\n\n        let components = manifest[\"components\"].as_array().unwrap();\n        assert_eq!(components.len(), 2);\n        for comp in components {\n            let sha = comp[\"sha256\"].as_str().unwrap();\n            assert_eq!(sha.len(), 64);\n            assert!(!comp[\"files\"].as_array().unwrap().is_empty());\n            assert_eq!(comp[\"proof_system\"], \"jstprove\");\n            assert!(!comp[\"weights\"].as_array().unwrap().is_empty());\n        }\n\n        let dag = manifest[\"dag\"].as_array().unwrap();\n        assert_eq!(dag.len(), 2);\n        assert_eq!(dag[0][\"inputs\"][0], \"model_input\");\n        assert_eq!(dag[0][\"outputs\"][0], \"tensor_0\");\n        assert_eq!(dag[1][\"inputs\"][0], \"tensor_0\");\n    }\n\n    #[test]\n    fn test_component_files_exist() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 1);\n\n        let output_dir = tmp.path().join(\"output\");\n        let config = PackageConfig {\n            output_dir: output_dir.clone(),\n\n            author: None,\n            model_version: None,\n            
model_name: None,\n            timeout: None,\n            curve: None,\n        };\n\n        package_content_addressed(&slices_dir, &config).unwrap();\n\n        let manifest: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(output_dir.join(\"manifest.msgpack\")).unwrap()).unwrap();\n\n        let comp = &manifest[\"components\"][0];\n        let sha = comp[\"sha256\"].as_str().unwrap();\n        let comp_dir = output_dir.join(\"components\").join(sha);\n        assert!(comp_dir.is_dir());\n        assert!(comp_dir.join(\"circuit.bin\").is_file());\n        assert!(comp_dir.join(\"settings.json\").is_file());\n    }\n\n    #[test]\n    fn test_wb_files_exist() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 1);\n\n        let output_dir = tmp.path().join(\"output\");\n        let config = PackageConfig {\n            output_dir: output_dir.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n\n        package_content_addressed(&slices_dir, &config).unwrap();\n\n        let manifest: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(output_dir.join(\"manifest.msgpack\")).unwrap()).unwrap();\n\n        let weight = &manifest[\"components\"][0][\"weights\"][0];\n        let sha = weight[\"sha256\"].as_str().unwrap();\n        let wb_path = output_dir.join(\"wb\").join(sha);\n        assert!(wb_path.is_file());\n\n        let size = weight[\"size_bytes\"].as_u64().unwrap();\n        assert_eq!(fs::metadata(&wb_path).unwrap().len(), size);\n    }\n\n    #[test]\n    fn test_hash_determinism() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        
fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 2);\n\n        let out1 = tmp.path().join(\"out1\");\n        let out2 = tmp.path().join(\"out2\");\n\n        let config1 = PackageConfig {\n            output_dir: out1.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n        let config2 = PackageConfig {\n            output_dir: out2.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n\n        package_content_addressed(&slices_dir, &config1).unwrap();\n        package_content_addressed(&slices_dir, &config2).unwrap();\n\n        let m1: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out1.join(\"manifest.msgpack\")).unwrap()).unwrap();\n        let m2: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out2.join(\"manifest.msgpack\")).unwrap()).unwrap();\n\n        for i in 0..2 {\n            assert_eq!(m1[\"components\"][i][\"sha256\"], m2[\"components\"][i][\"sha256\"]);\n        }\n    }\n\n    #[test]\n    fn test_curve_changes_hash() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 2);\n\n        let out_none = tmp.path().join(\"out_none\");\n        let out_bn = tmp.path().join(\"out_bn\");\n        let out_gl = tmp.path().join(\"out_gl\");\n\n        let config_none = PackageConfig {\n            output_dir: out_none.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n        let config_bn = PackageConfig {\n            output_dir: out_bn.clone(),\n\n            
author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: Some(\"bn254\".to_string()),\n        };\n        let config_gl = PackageConfig {\n            output_dir: out_gl.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: Some(\"goldilocks\".to_string()),\n        };\n\n        package_content_addressed(&slices_dir, &config_none).unwrap();\n        package_content_addressed(&slices_dir, &config_bn).unwrap();\n        package_content_addressed(&slices_dir, &config_gl).unwrap();\n\n        let m_none: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out_none.join(\"manifest.msgpack\")).unwrap()).unwrap();\n        let m_bn: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out_bn.join(\"manifest.msgpack\")).unwrap()).unwrap();\n        let m_gl: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out_gl.join(\"manifest.msgpack\")).unwrap()).unwrap();\n\n        for i in 0..2 {\n            let h_none = m_none[\"components\"][i][\"sha256\"].as_str().unwrap();\n            let h_bn = m_bn[\"components\"][i][\"sha256\"].as_str().unwrap();\n            let h_gl = m_gl[\"components\"][i][\"sha256\"].as_str().unwrap();\n            assert_ne!(h_none, h_bn, \"curve=None vs bn254 should differ\");\n            assert_ne!(h_none, h_gl, \"curve=None vs goldilocks should differ\");\n            assert_ne!(h_bn, h_gl, \"bn254 vs goldilocks should differ\");\n        }\n    }\n\n    #[test]\n    fn test_curve_changes_hash_uncompiled_onnx() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n\n        let slice_dir = slices_dir.join(\"slice_0\");\n        let payload_dir = slice_dir.join(\"payload\");\n        fs::create_dir_all(&payload_dir).unwrap();\n   
     fs::write(payload_dir.join(\"slice_0.onnx\"), \"onnx_payload\").unwrap();\n\n        let meta = ModelMetadata {\n            original_model: \"test\".to_string(),\n            model_type: \"onnx\".to_string(),\n            input_shape: vec![vec![1, 3]],\n            output_shapes: vec![vec![1, 3]],\n            output_names: vec![\"out\".to_string()],\n            slice_points: vec![0],\n            slices: vec![SliceMetadata {\n                index: 0,\n                filename: \"slice_0.onnx\".to_string(),\n                path: slice_dir.to_string_lossy().to_string(),\n                relative_path: \"slice_0/payload/slice_0.onnx\".to_string(),\n                shape: SliceShapeWrapper {\n                    tensor_shape: TensorShape {\n                        input: vec![vec![1, 3]],\n                        output: vec![vec![1, 3]],\n                    },\n                },\n                dependencies: Dependencies {\n                    input: vec![\"in\".to_string()],\n                    output: vec![\"out\".to_string()],\n                    filtered_inputs: vec![],\n                },\n                tiling: None,\n                channel_split: None,\n                dim_split: None,\n                compilation: Compilation::default(),\n                slice_metadata: None,\n                slice_metadata_relative_path: None,\n            }],\n            dsperse_version: None,\n            dsperse_rev: None,\n            jstprove_version: None,\n            jstprove_rev: None,\n            traced_shapes: None,\n            traced_types: None,\n            original_model_path: None,\n            folded_constant_names: vec![],\n        };\n        meta.save(&slices_dir.join(\"metadata.msgpack\")).unwrap();\n        ensure_test_artifacts(&slices_dir);\n        let out_none = tmp.path().join(\"out_none\");\n        let out_bn = tmp.path().join(\"out_bn\");\n        let out_gl = tmp.path().join(\"out_gl\");\n\n        let config_none = 
PackageConfig {\n            output_dir: out_none.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n        let config_bn = PackageConfig {\n            output_dir: out_bn.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: Some(\"bn254\".to_string()),\n        };\n        let config_gl = PackageConfig {\n            output_dir: out_gl.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: Some(\"goldilocks\".to_string()),\n        };\n\n        package_content_addressed(&slices_dir, &config_none).unwrap();\n        package_content_addressed(&slices_dir, &config_bn).unwrap();\n        package_content_addressed(&slices_dir, &config_gl).unwrap();\n\n        let m_none: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out_none.join(\"manifest.msgpack\")).unwrap()).unwrap();\n        let m_bn: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out_bn.join(\"manifest.msgpack\")).unwrap()).unwrap();\n        let m_gl: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out_gl.join(\"manifest.msgpack\")).unwrap()).unwrap();\n\n        let h_none = m_none[\"components\"][0][\"sha256\"].as_str().unwrap();\n        let h_bn = m_bn[\"components\"][0][\"sha256\"].as_str().unwrap();\n        let h_gl = m_gl[\"components\"][0][\"sha256\"].as_str().unwrap();\n        assert_ne!(h_none, h_bn, \"onnx: curve=None vs bn254 should differ\");\n        assert_ne!(h_none, h_gl, \"onnx: curve=None vs goldilocks should differ\");\n        assert_ne!(h_bn, h_gl, \"onnx: bn254 vs goldilocks should differ\");\n    }\n\n    #[test]\n    fn test_invalid_curve_rejected() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = 
tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 1);\n\n        let config_typo = PackageConfig {\n            output_dir: tmp.path().join(\"output\"),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: Some(\"bm254\".to_string()),\n        };\n        let result = package_content_addressed(&slices_dir, &config_typo);\n        assert!(result.is_err());\n\n        let config_empty = PackageConfig {\n            output_dir: tmp.path().join(\"output2\"),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: Some(\"\".to_string()),\n        };\n        let result = package_content_addressed(&slices_dir, &config_empty);\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn test_curve_normalization() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n        create_test_model_metadata(&slices_dir, 1);\n\n        let out1 = tmp.path().join(\"out1\");\n        let out2 = tmp.path().join(\"out2\");\n        let out3 = tmp.path().join(\"out3\");\n\n        let config1 = PackageConfig {\n            output_dir: out1.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: Some(\"bn254\".to_string()),\n        };\n        let config2 = PackageConfig {\n            output_dir: out2.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: Some(\" bn254 \".to_string()),\n        };\n        let config3 = PackageConfig {\n            output_dir: out3.clone(),\n\n            author: None,\n            
model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: Some(\"BN254\".to_string()),\n        };\n\n        package_content_addressed(&slices_dir, &config1).unwrap();\n        package_content_addressed(&slices_dir, &config2).unwrap();\n        package_content_addressed(&slices_dir, &config3).unwrap();\n\n        let m1: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out1.join(\"manifest.msgpack\")).unwrap()).unwrap();\n        let m2: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out2.join(\"manifest.msgpack\")).unwrap()).unwrap();\n        let m3: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(out3.join(\"manifest.msgpack\")).unwrap()).unwrap();\n\n        assert_eq!(m1[\"components\"][0][\"sha256\"], m2[\"components\"][0][\"sha256\"]);\n        assert_eq!(m1[\"components\"][0][\"sha256\"], m3[\"components\"][0][\"sha256\"]);\n    }\n\n    #[test]\n    fn test_deduplication_shared_circuits() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n\n        let mut slices = Vec::new();\n\n        for i in 0..3 {\n            let slice_dir = slices_dir.join(format!(\"slice_{}\", i));\n            let payload_dir = slice_dir.join(\"payload\");\n            fs::create_dir_all(&payload_dir).unwrap();\n            write_minimal_onnx(&payload_dir.join(format!(\"slice_{}.onnx\", i)), 64);\n            let circuit_dir = slice_dir.join(\"jstprove\").join(\"circuit.bundle\");\n            fs::create_dir_all(&circuit_dir).unwrap();\n            fs::write(circuit_dir.join(\"circuit.bin\"), \"shared_circuit_data\").unwrap();\n\n            slices.push(SliceMetadata {\n                index: i,\n                filename: format!(\"slice_{}.onnx\", i),\n                path: slice_dir.to_string_lossy().to_string(),\n                relative_path: 
format!(\"slice_{}/payload/slice_{}.onnx\", i, i),\n                shape: SliceShapeWrapper {\n                    tensor_shape: TensorShape {\n                        input: vec![vec![1, 64]],\n                        output: vec![vec![1, 64]],\n                    },\n                },\n                dependencies: Dependencies {\n                    input: vec![format!(\"t_{}\", i)],\n                    output: vec![format!(\"t_{}\", i + 1)],\n                    filtered_inputs: vec![],\n                },\n                tiling: None,\n                channel_split: None,\n                dim_split: None,\n                compilation: Compilation::default(),\n                slice_metadata: None,\n                slice_metadata_relative_path: None,\n            });\n        }\n\n        let meta = ModelMetadata {\n            original_model: \"shared_test\".to_string(),\n            model_type: \"onnx\".to_string(),\n            input_shape: vec![vec![1, 64]],\n            output_shapes: vec![vec![1, 64]],\n            output_names: vec![\"out\".to_string()],\n            slice_points: vec![0, 1, 2],\n            slices,\n            dsperse_version: None,\n            dsperse_rev: None,\n            jstprove_version: None,\n            jstprove_rev: None,\n            traced_shapes: None,\n            traced_types: None,\n            original_model_path: None,\n            folded_constant_names: vec![],\n        };\n        meta.save(&slices_dir.join(\"metadata.msgpack\")).unwrap();\n        ensure_test_artifacts(&slices_dir);\n        let output_dir = tmp.path().join(\"output\");\n        let config = PackageConfig {\n            output_dir: output_dir.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n\n        let result = package_content_addressed(&slices_dir, &config).unwrap();\n\n        assert_eq!(result.component_count, 1);\n     
   assert_eq!(result.wb_count, 3);\n\n        let manifest: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(output_dir.join(\"manifest.msgpack\")).unwrap()).unwrap();\n        let components = manifest[\"components\"].as_array().unwrap();\n        let hash0 = components[0][\"sha256\"].as_str().unwrap();\n        let hash1 = components[1][\"sha256\"].as_str().unwrap();\n        let hash2 = components[2][\"sha256\"].as_str().unwrap();\n        assert_eq!(hash0, hash1);\n        assert_eq!(hash1, hash2);\n    }\n\n    #[test]\n    fn test_uncompiled_onnx_only_slice() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n\n        let slice_dir = slices_dir.join(\"slice_0\");\n        let payload_dir = slice_dir.join(\"payload\");\n        fs::create_dir_all(&payload_dir).unwrap();\n        fs::write(payload_dir.join(\"slice_0.onnx\"), \"onnx_payload_data\").unwrap();\n\n        let meta = ModelMetadata {\n            original_model: \"test\".to_string(),\n            model_type: \"onnx\".to_string(),\n            input_shape: vec![vec![1, 3, 224, 224]],\n            output_shapes: vec![vec![1, 1000]],\n            output_names: vec![\"output\".to_string()],\n            slice_points: vec![0],\n            slices: vec![SliceMetadata {\n                index: 0,\n                filename: \"slice_0.onnx\".to_string(),\n                path: slice_dir.to_string_lossy().to_string(),\n                relative_path: \"slice_0/payload/slice_0.onnx\".to_string(),\n                shape: SliceShapeWrapper {\n                    tensor_shape: TensorShape {\n                        input: vec![vec![1, 3, 224, 224]],\n                        output: vec![vec![1, 1000]],\n                    },\n                },\n                dependencies: Dependencies {\n                    input: vec![\"input\".to_string()],\n                    output: 
vec![\"output\".to_string()],\n                    filtered_inputs: vec![],\n                },\n                tiling: None,\n                channel_split: None,\n                dim_split: None,\n                compilation: Compilation::default(),\n                slice_metadata: None,\n                slice_metadata_relative_path: None,\n            }],\n            dsperse_version: None,\n            dsperse_rev: None,\n            jstprove_version: None,\n            jstprove_rev: None,\n            traced_shapes: None,\n            traced_types: None,\n            original_model_path: None,\n            folded_constant_names: vec![],\n        };\n        meta.save(&slices_dir.join(\"metadata.msgpack\")).unwrap();\n        ensure_test_artifacts(&slices_dir);\n        let output_dir = tmp.path().join(\"output\");\n        let config = PackageConfig {\n            output_dir: output_dir.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n\n        let result = package_content_addressed(&slices_dir, &config).unwrap();\n        assert_eq!(result.component_count, 1);\n\n        let manifest: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(output_dir.join(\"manifest.msgpack\")).unwrap()).unwrap();\n\n        let comp = &manifest[\"components\"][0];\n        assert!(comp[\"proof_system\"].is_null());\n        let sha = comp[\"sha256\"].as_str().unwrap();\n        let files = comp[\"files\"].as_array().unwrap();\n        assert_eq!(files.len(), 1);\n        assert_eq!(files[0], \"slice_0.onnx\");\n\n        let comp_dir = output_dir.join(\"components\").join(sha);\n        assert!(comp_dir.join(\"slice_0.onnx\").is_file());\n    }\n\n    #[test]\n    fn test_missing_artifact_errors() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        
fs::create_dir_all(&slices_dir).unwrap();\n\n        let slice_dir = slices_dir.join(\"slice_0\");\n        fs::create_dir_all(&slice_dir).unwrap();\n\n        let meta = ModelMetadata {\n            original_model: \"test\".to_string(),\n            model_type: \"onnx\".to_string(),\n            input_shape: vec![vec![1]],\n            output_shapes: vec![vec![1]],\n            output_names: vec![\"out\".to_string()],\n            slice_points: vec![0],\n            slices: vec![SliceMetadata {\n                index: 0,\n                filename: \"slice_0.onnx\".to_string(),\n                path: slice_dir.to_string_lossy().to_string(),\n                relative_path: \"slice_0/payload/slice_0.onnx\".to_string(),\n                shape: SliceShapeWrapper {\n                    tensor_shape: TensorShape {\n                        input: vec![vec![1]],\n                        output: vec![vec![1]],\n                    },\n                },\n                dependencies: Dependencies {\n                    input: vec![\"in\".to_string()],\n                    output: vec![\"out\".to_string()],\n                    filtered_inputs: vec![],\n                },\n                tiling: None,\n                channel_split: None,\n                dim_split: None,\n                compilation: Compilation::default(),\n                slice_metadata: None,\n                slice_metadata_relative_path: None,\n            }],\n            dsperse_version: None,\n            dsperse_rev: None,\n            jstprove_version: None,\n            jstprove_rev: None,\n            traced_shapes: None,\n            traced_types: None,\n            original_model_path: None,\n            folded_constant_names: vec![],\n        };\n        meta.save(&slices_dir.join(\"metadata.msgpack\")).unwrap();\n        ensure_test_artifacts(&slices_dir);\n        let config = PackageConfig {\n            output_dir: tmp.path().join(\"output\"),\n\n            author: None,\n            
model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n\n        let result = package_content_addressed(&slices_dir, &config);\n        assert!(result.is_err());\n        let err = result.unwrap_err().to_string();\n        assert!(\n            err.contains(\"no circuit directory or ONNX artifact\"),\n            \"unexpected error: {err}\"\n        );\n    }\n\n    #[test]\n    fn test_path_traversal_rejected() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n\n        let slice_dir = slices_dir.join(\"slice_0\");\n        let payload_dir = slice_dir.join(\"payload\");\n        fs::create_dir_all(&payload_dir).unwrap();\n        fs::write(payload_dir.join(\"slice_0.onnx\"), \"data\").unwrap();\n\n        let meta = ModelMetadata {\n            original_model: \"test\".to_string(),\n            model_type: \"onnx\".to_string(),\n            input_shape: vec![vec![1]],\n            output_shapes: vec![vec![1]],\n            output_names: vec![\"out\".to_string()],\n            slice_points: vec![0],\n            slices: vec![SliceMetadata {\n                index: 0,\n                filename: \"slice_0.onnx\".to_string(),\n                path: slice_dir.to_string_lossy().to_string(),\n                relative_path: \"../../etc/passwd\".to_string(),\n                shape: SliceShapeWrapper {\n                    tensor_shape: TensorShape {\n                        input: vec![vec![1]],\n                        output: vec![vec![1]],\n                    },\n                },\n                dependencies: Dependencies {\n                    input: vec![\"in\".to_string()],\n                    output: vec![\"out\".to_string()],\n                    filtered_inputs: vec![],\n                },\n                tiling: None,\n                channel_split: None,\n                
dim_split: None,\n                compilation: Compilation::default(),\n                slice_metadata: None,\n                slice_metadata_relative_path: None,\n            }],\n            dsperse_version: None,\n            dsperse_rev: None,\n            jstprove_version: None,\n            jstprove_rev: None,\n            traced_shapes: None,\n            traced_types: None,\n            original_model_path: None,\n            folded_constant_names: vec![],\n        };\n        meta.save(&slices_dir.join(\"metadata.msgpack\")).unwrap();\n        ensure_test_artifacts(&slices_dir);\n        let config = PackageConfig {\n            output_dir: tmp.path().join(\"output\"),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n\n        let result = package_content_addressed(&slices_dir, &config);\n        assert!(result.is_err());\n        let err = result.unwrap_err().to_string();\n        assert!(\n            err.contains(\"path traversal\"),\n            \"expected path traversal error, got: {err}\"\n        );\n    }\n\n    #[test]\n    fn test_nonexistent_dir() {\n        let config = PackageConfig {\n            output_dir: PathBuf::from(\"/tmp/nonexistent_output\"),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n        let result = package_content_addressed(Path::new(\"/nonexistent/path\"), &config);\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn test_identical_bytes_different_filenames_distinct_hashes() {\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n\n        let identical_data = \"identical_onnx_content\";\n\n        let mut slices = Vec::new();\n        for i in 0..2 {\n            let 
slice_dir = slices_dir.join(format!(\"slice_{}\", i));\n            let payload_dir = slice_dir.join(\"payload\");\n            fs::create_dir_all(&payload_dir).unwrap();\n            fs::write(\n                payload_dir.join(format!(\"slice_{}.onnx\", i)),\n                identical_data,\n            )\n            .unwrap();\n\n            slices.push(SliceMetadata {\n                index: i,\n                filename: format!(\"slice_{}.onnx\", i),\n                path: slice_dir.to_string_lossy().to_string(),\n                relative_path: format!(\"slice_{}/payload/slice_{}.onnx\", i, i),\n                shape: SliceShapeWrapper {\n                    tensor_shape: TensorShape {\n                        input: vec![vec![1]],\n                        output: vec![vec![1]],\n                    },\n                },\n                dependencies: Dependencies {\n                    input: vec![format!(\"t_{}\", i)],\n                    output: vec![format!(\"t_{}\", i + 1)],\n                    filtered_inputs: vec![],\n                },\n                tiling: None,\n                channel_split: None,\n                dim_split: None,\n                compilation: Compilation::default(),\n                slice_metadata: None,\n                slice_metadata_relative_path: None,\n            });\n        }\n\n        let meta = ModelMetadata {\n            original_model: \"test\".to_string(),\n            model_type: \"onnx\".to_string(),\n            input_shape: vec![vec![1]],\n            output_shapes: vec![vec![1]],\n            output_names: vec![\"out\".to_string()],\n            slice_points: vec![0, 1],\n            slices,\n            dsperse_version: None,\n            dsperse_rev: None,\n            jstprove_version: None,\n            jstprove_rev: None,\n            traced_shapes: None,\n            traced_types: None,\n            original_model_path: None,\n            folded_constant_names: vec![],\n        };\n        
meta.save(&slices_dir.join(\"metadata.msgpack\")).unwrap();\n        ensure_test_artifacts(&slices_dir);\n        let output_dir = tmp.path().join(\"output\");\n        let config = PackageConfig {\n            output_dir: output_dir.clone(),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n\n        let result = package_content_addressed(&slices_dir, &config).unwrap();\n        assert_eq!(result.component_count, 2);\n\n        let manifest: serde_json::Value =\n            rmp_serde::from_slice(&fs::read(output_dir.join(\"manifest.msgpack\")).unwrap()).unwrap();\n        let c0 = &manifest[\"components\"][0];\n        let c1 = &manifest[\"components\"][1];\n        assert_ne!(c0[\"sha256\"], c1[\"sha256\"]);\n\n        let dir0 = output_dir\n            .join(\"components\")\n            .join(c0[\"sha256\"].as_str().unwrap());\n        let dir1 = output_dir\n            .join(\"components\")\n            .join(c1[\"sha256\"].as_str().unwrap());\n        assert!(dir0.join(\"slice_0.onnx\").is_file());\n        assert!(dir1.join(\"slice_1.onnx\").is_file());\n    }\n\n    #[test]\n    #[cfg(unix)]\n    fn test_symlink_payload_rejected() {\n        use std::os::unix::fs::symlink;\n\n        let tmp = TempDir::new().unwrap();\n        let slices_dir = tmp.path().join(\"model\").join(\"slices\");\n        fs::create_dir_all(&slices_dir).unwrap();\n\n        let external = tmp.path().join(\"external_secret.bin\");\n        fs::write(&external, \"sensitive data\").unwrap();\n\n        let slice_dir = slices_dir.join(\"slice_0\");\n        let payload_dir = slice_dir.join(\"payload\");\n        fs::create_dir_all(&payload_dir).unwrap();\n        symlink(&external, payload_dir.join(\"slice_0.onnx\")).unwrap();\n\n        let meta = ModelMetadata {\n            original_model: \"test\".to_string(),\n            model_type: \"onnx\".to_string(),\n            
input_shape: vec![vec![1]],\n            output_shapes: vec![vec![1]],\n            output_names: vec![\"out\".to_string()],\n            slice_points: vec![0],\n            slices: vec![SliceMetadata {\n                index: 0,\n                filename: \"slice_0.onnx\".to_string(),\n                path: slice_dir.to_string_lossy().to_string(),\n                relative_path: \"slice_0/payload/slice_0.onnx\".to_string(),\n                shape: SliceShapeWrapper {\n                    tensor_shape: TensorShape {\n                        input: vec![vec![1]],\n                        output: vec![vec![1]],\n                    },\n                },\n                dependencies: Dependencies {\n                    input: vec![\"in\".to_string()],\n                    output: vec![\"out\".to_string()],\n                    filtered_inputs: vec![],\n                },\n                tiling: None,\n                channel_split: None,\n                dim_split: None,\n                compilation: Compilation::default(),\n                slice_metadata: None,\n                slice_metadata_relative_path: None,\n            }],\n            dsperse_version: None,\n            dsperse_rev: None,\n            jstprove_version: None,\n            jstprove_rev: None,\n            traced_shapes: None,\n            traced_types: None,\n            original_model_path: None,\n            folded_constant_names: vec![],\n        };\n        meta.save(&slices_dir.join(\"metadata.msgpack\")).unwrap();\n        ensure_test_artifacts(&slices_dir);\n        let config = PackageConfig {\n            output_dir: tmp.path().join(\"output\"),\n\n            author: None,\n            model_version: None,\n            model_name: None,\n            timeout: None,\n            curve: None,\n        };\n\n        let result = package_content_addressed(&slices_dir, &config);\n        assert!(result.is_err());\n        let err = result.unwrap_err().to_string();\n        assert!(\n     
       err.contains(\"symlink\"),\n            \"expected symlink error, got: {err}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/prover.rs",
    "content": "use std::path::Path;\n\nuse crate::backend::ProofBackend;\nuse crate::error::Result;\nuse crate::schema::execution::RunMetadata;\n\nuse super::stage::{PipelineStage, run_pipeline_stage};\n\npub fn prove_run(\n    run_dir: &Path,\n    slices_dir: &Path,\n    backend: &dyn ProofBackend,\n    parallel: usize,\n) -> Result<RunMetadata> {\n    run_pipeline_stage(PipelineStage::Prove, run_dir, slices_dir, backend, parallel)\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/publisher.rs",
    "content": "use std::fs;\nuse std::path::Path;\nuse std::time::Duration;\n\nuse sha2::{Digest, Sha256};\n\nuse crate::error::{DsperseError, Result};\n\nconst REQUEST_TIMEOUT: Duration = Duration::from_secs(30);\nconst UPLOAD_TIMEOUT: Duration = Duration::from_secs(300);\n\npub struct PublishConfig {\n    pub api_url: String,\n    pub auth_token: String,\n    pub name: String,\n    pub description: String,\n    pub author: String,\n    pub version: String,\n    pub proof_system: String,\n    pub timeout: u64,\n    pub activate: bool,\n}\n\npub struct PublishResult {\n    pub model_id: String,\n    pub components_uploaded: usize,\n    pub components_skipped: usize,\n    pub weights_uploaded: usize,\n    pub weights_skipped: usize,\n}\n\nfn auth_header(token: &str) -> String {\n    format!(\"Bearer {token}\")\n}\n\npub fn publish(dir: &Path, config: &PublishConfig) -> Result<PublishResult> {\n    let rt = tokio::runtime::Builder::new_current_thread()\n        .enable_all()\n        .build()\n        .map_err(|e| DsperseError::Other(format!(\"tokio runtime: {e}\")))?;\n\n    rt.block_on(publish_async(dir, config))\n}\n\nasync fn publish_async(dir: &Path, config: &PublishConfig) -> Result<PublishResult> {\n    let manifest_path = dir.join(\"manifest.msgpack\");\n    if !manifest_path.is_file() {\n        return Err(DsperseError::Other(format!(\n            \"manifest.msgpack not found in {}\",\n            dir.display()\n        )));\n    }\n\n    let manifest_bytes =\n        fs::read(&manifest_path).map_err(|e| DsperseError::io(e, &manifest_path))?;\n    let manifest: serde_json::Value = rmp_serde::from_slice(&manifest_bytes)\n        .map_err(|e| DsperseError::Other(format!(\"failed to parse manifest: {e}\")))?;\n\n    let components = manifest[\"components\"]\n        .as_array()\n        .ok_or_else(|| DsperseError::Other(\"manifest missing components array\".into()))?;\n\n    let dag = manifest[\"dag\"]\n        .as_array()\n        .ok_or_else(|| 
DsperseError::Other(\"manifest missing dag array\".into()))?;\n\n    let client = reqwest::Client::builder()\n        .timeout(REQUEST_TIMEOUT)\n        .build()\n        .map_err(|e| DsperseError::Other(format!(\"http client: {e}\")))?;\n    let api = config.api_url.trim_end_matches('/');\n    let auth = auth_header(&config.auth_token);\n\n    let mut components_uploaded = 0usize;\n    let mut components_skipped = 0usize;\n    let mut weights_uploaded = 0usize;\n    let mut weights_skipped = 0usize;\n\n    for comp in components {\n        let sha = comp[\"sha256\"]\n            .as_str()\n            .ok_or_else(|| DsperseError::Other(\"component missing sha256\".into()))?;\n        let files: Vec<String> = comp[\"files\"]\n            .as_array()\n            .ok_or_else(|| DsperseError::Other(\"component missing files\".into()))?\n            .iter()\n            .filter_map(|v| v.as_str().map(String::from))\n            .collect();\n\n        // Verify the component by probing each file the manifest\n        // expects to live in blob storage, not by asking for a\n        // metadata row.  The registry registers the component row\n        // as soon as POST /admin/components returns, but the\n        // per-file PUTs against the pre-signed upload URLs happen\n        // afterwards -- any failure there (timeout, network blip,\n        // interrupted publish process) leaves the row present with\n        // no backing files.  A plain GET /components/{sha} sees the\n        // row and reports \"exists, skipping\", so every subsequent\n        // publish run re-skips the broken component and the model\n        // stays permanently half-uploaded from downstream\n        // consumers' perspective.\n        //\n        // Mirror the byte-level presence check weight-blob uploads\n        // below already use: HEAD each expected file by issuing a\n        // single-byte ranged GET to the blob path.  
If every file\n        // is present, the component is genuinely done and we skip.\n        // If every file is missing, proceed to the normal\n        // register + upload path.  If the set is partially present\n        // (registered but mid-upload), surface an actionable error\n        // instead of silently continuing, because re-registering\n        // via POST /admin/components will 409 and the current flow\n        // has no way to request fresh upload URLs for that sha.\n        // A manifest entry with zero files is malformed -- the\n        // empty-list case would otherwise make both\n        // `missing.is_empty()` and `present == files.len()` true\n        // below, silently classifying the component as present\n        // without any actual bytes verified.  Fail loud.\n        if files.is_empty() {\n            return Err(DsperseError::Other(format!(\n                \"component {sha} has no files listed in the manifest; refusing to treat as present\"\n            )));\n        }\n\n        let mut present = 0usize;\n        let mut missing: Vec<String> = Vec::new();\n        for filename in &files {\n            let file_url = format!(\"{api}/components/{sha}/files/{filename}\");\n            let probe = client\n                .get(&file_url)\n                .header(\"Range\", \"bytes=0-0\")\n                .send()\n                .await\n                .map_err(|e| DsperseError::Other(format!(\"probe {sha}/{filename}: {e}\")))?;\n            // A Range: bytes=0-0 GET against a blob path has two\n            // legitimate success replies: 206 (partial content,\n            // what the blob store returns when it honours the\n            // range) and 200 (full content, what it returns when\n            // it ignores the range for an empty body or tiny file).\n            // Any other 2xx (201 Created, 202 Accepted, 204 No\n            // Content) is ambiguous for a GET on a CAS path and\n            // should not be interpreted as \"file 
present\".\n            let status = probe.status();\n            match status.as_u16() {\n                200 | 206 => present += 1,\n                404 => missing.push(filename.clone()),\n                _ => {\n                    let text = probe.text().await.unwrap_or_default();\n                    return Err(DsperseError::Other(format!(\n                        \"probe component {sha}/{filename} returned unexpected status ({status}): {text}\"\n                    )));\n                }\n            }\n        }\n\n        if missing.is_empty() && present == files.len() {\n            tracing::info!(sha = %sha, \"component files present, skipping\");\n            components_skipped += 1;\n            continue;\n        }\n        if present > 0 {\n            return Err(DsperseError::Other(format!(\n                \"component {sha} is partially uploaded: {present}/{} files present, \\\n                 missing: {:?}.  A previous publish registered the component row but \\\n                 some PUTs did not complete.  
Run \\\n                 `curl -X DELETE -H 'Authorization: Bearer $REGISTRY_AUTH_TOKEN' \\\n                 {api}/admin/components/{sha}` to drop the stale row, then re-run \\\n                 publish so the full register + upload flow can replay for this sha.\",\n                files.len(),\n                missing\n            )));\n        }\n\n        let proof_system = comp[\"proof_system\"]\n            .as_str()\n            .unwrap_or(&config.proof_system)\n            .to_uppercase();\n        let comp_name = comp[\"name\"].as_str().unwrap_or(sha);\n\n        tracing::info!(sha = %sha, files = files.len(), \"registering component\");\n        let register_resp = client\n            .post(format!(\"{api}/admin/components\"))\n            .header(\"Authorization\", &auth)\n            .json(&serde_json::json!({\n                \"sha256\": sha,\n                \"name\": comp_name,\n                \"description\": \"\",\n                \"proof_system\": proof_system,\n                \"files\": files,\n            }))\n            .send()\n            .await\n            .map_err(|e| DsperseError::Other(format!(\"register component {sha}: {e}\")))?;\n\n        let reg_status = register_resp.status();\n        if reg_status.as_u16() == 409 {\n            tracing::info!(sha = %sha, \"component already registered (conflict)\");\n            components_skipped += 1;\n            continue;\n        }\n        if !reg_status.is_success() {\n            let text = register_resp.text().await.unwrap_or_default();\n            if text.contains(\"already exists\") {\n                tracing::info!(sha = %sha, \"component already registered\");\n                components_skipped += 1;\n                continue;\n            }\n            return Err(DsperseError::Other(format!(\n                \"register component {sha} failed ({reg_status}): {text}\"\n            )));\n        }\n\n        let resp_body: serde_json::Value = register_resp\n            .json()\n  
          .await\n            .map_err(|e| DsperseError::Other(format!(\"parse component response: {e}\")))?;\n\n        let upload_urls = resp_body[\"upload_urls\"]\n            .as_object()\n            .ok_or_else(|| DsperseError::Other(\"missing upload_urls for component\".into()))?;\n\n        let comp_dir = dir.join(\"components\").join(sha);\n        for (filename, url_val) in upload_urls {\n            let url = url_val\n                .as_str()\n                .ok_or_else(|| DsperseError::Other(format!(\"non-string URL for {filename}\")))?;\n            let file_path = comp_dir.join(filename);\n            let data = fs::read(&file_path).map_err(|e| DsperseError::io(e, &file_path))?;\n\n            tracing::info!(file = %filename, size = data.len(), \"uploading component file\");\n            let put = client\n                .put(url)\n                .timeout(UPLOAD_TIMEOUT)\n                .header(\"Content-Type\", \"application/octet-stream\")\n                .body(data)\n                .send()\n                .await\n                .map_err(|e| DsperseError::Other(format!(\"upload {filename}: {e}\")))?;\n\n            if !put.status().is_success() {\n                return Err(DsperseError::Other(format!(\n                    \"upload component file {filename} failed ({})\",\n                    put.status()\n                )));\n            }\n        }\n\n        components_uploaded += 1;\n    }\n\n    let mut all_weight_refs: Vec<&serde_json::Value> = Vec::new();\n    if let Some(artifacts) = manifest[\"artifacts\"].as_array() {\n        all_weight_refs.extend(artifacts);\n    }\n    for comp in components {\n        if let Some(weights) = comp[\"weights\"].as_array() {\n            all_weight_refs.extend(weights);\n        }\n    }\n\n    let mut uploaded_wbs: std::collections::HashSet<String> = std::collections::HashSet::new();\n    for wref in &all_weight_refs {\n        let sha = wref[\"sha256\"]\n            .as_str()\n            
.ok_or_else(|| DsperseError::Other(\"weight ref missing sha256\".into()))?;\n\n        if uploaded_wbs.contains(sha) {\n            continue;\n        }\n\n        let size = wref[\"size_bytes\"].as_u64().unwrap_or(0);\n\n        let check = client\n            .get(format!(\"{api}/models/wb/{sha}\"))\n            .header(\"Range\", \"bytes=0-0\")\n            .send()\n            .await\n            .map_err(|e| DsperseError::Other(format!(\"check wb {sha}: {e}\")))?;\n\n        if check.status().is_success() || check.status().as_u16() == 206 {\n            tracing::info!(sha = %sha, \"weight blob exists, skipping\");\n            weights_skipped += 1;\n            uploaded_wbs.insert(sha.to_string());\n            continue;\n        }\n        if check.status().as_u16() != 404 {\n            let status = check.status();\n            let text = check.text().await.unwrap_or_default();\n            return Err(DsperseError::Other(format!(\n                \"probe wb {sha} returned unexpected status ({status}): {text}\"\n            )));\n        }\n\n        let name = wref[\"role\"].as_str().unwrap_or(\"\");\n        tracing::info!(sha = %sha, size, \"registering weight blob\");\n        let wb_resp = client\n            .post(format!(\"{api}/admin/models/wb\"))\n            .header(\"Authorization\", &auth)\n            .json(&serde_json::json!({\n                \"sha256\": sha,\n                \"name\": name,\n                \"size_bytes\": size,\n            }))\n            .send()\n            .await\n            .map_err(|e| DsperseError::Other(format!(\"register wb {sha}: {e}\")))?;\n\n        let wb_status = wb_resp.status();\n        if wb_status.as_u16() == 409 {\n            tracing::info!(sha = %sha, \"weight blob already registered (conflict)\");\n            weights_skipped += 1;\n            uploaded_wbs.insert(sha.to_string());\n            continue;\n        }\n        if !wb_status.is_success() {\n            let text = 
wb_resp.text().await.unwrap_or_default();\n            if text.contains(\"already exists\") {\n                tracing::info!(sha = %sha, \"weight blob already registered\");\n                weights_skipped += 1;\n                uploaded_wbs.insert(sha.to_string());\n                continue;\n            }\n            return Err(DsperseError::Other(format!(\n                \"register wb {sha} failed ({wb_status}): {text}\"\n            )));\n        }\n\n        let wb_body: serde_json::Value = wb_resp\n            .json()\n            .await\n            .map_err(|e| DsperseError::Other(format!(\"parse wb response: {e}\")))?;\n\n        match wb_body[\"upload_url\"].as_str() {\n            Some(upload_url) => {\n                let wb_path = dir.join(\"wb\").join(sha);\n                let data = fs::read(&wb_path).map_err(|e| DsperseError::io(e, &wb_path))?;\n\n                tracing::info!(sha = %sha, size = data.len(), \"uploading weight blob\");\n                let put = client\n                    .put(upload_url)\n                    .timeout(UPLOAD_TIMEOUT)\n                    .header(\"Content-Type\", \"application/octet-stream\")\n                    .body(data)\n                    .send()\n                    .await\n                    .map_err(|e| DsperseError::Other(format!(\"upload wb {sha}: {e}\")))?;\n\n                if !put.status().is_success() {\n                    return Err(DsperseError::Other(format!(\n                        \"upload wb {sha} failed ({})\",\n                        put.status()\n                    )));\n                }\n                weights_uploaded += 1;\n            }\n            None => {\n                return Err(DsperseError::Other(format!(\n                    \"registry returned no upload URL for weight blob {sha}\"\n                )));\n            }\n        }\n\n        uploaded_wbs.insert(sha.to_string());\n    }\n\n    let model_info = &manifest[\"model\"];\n    let model_name = 
model_info[\"name\"].as_str().unwrap_or(config.name.as_str());\n    let model_author = model_info[\"author\"]\n        .as_str()\n        .unwrap_or(config.author.as_str());\n    let model_version = model_info[\"version\"]\n        .as_str()\n        .unwrap_or(config.version.as_str());\n    let model_timeout = model_info[\"timeout\"].as_u64().unwrap_or(config.timeout);\n    let input_schema = &model_info[\"input_schema\"];\n    let dsperse_version = model_info[\"dsperse_version\"].as_str();\n    let jstprove_version = model_info[\"jstprove_version\"].as_str();\n\n    let artifacts = manifest[\"artifacts\"]\n        .as_array()\n        .cloned()\n        .unwrap_or_default();\n    let composition = serde_json::json!({\n        \"version\": 1,\n        \"artifacts\": artifacts,\n        \"components\": components,\n        \"dag\": dag,\n    });\n\n    let mut model_hasher = Sha256::new();\n    model_hasher.update(model_name.as_bytes());\n    model_hasher.update(b\"\\x00\");\n    model_hasher.update(model_author.as_bytes());\n    model_hasher.update(b\"\\x00\");\n    model_hasher.update(model_version.as_bytes());\n    model_hasher.update(b\"\\x00\");\n    model_hasher.update(model_timeout.to_le_bytes());\n    model_hasher.update(b\"\\x00\");\n    let comp_json = serde_json::to_string(&composition)\n        .map_err(|e| DsperseError::Other(format!(\"serialize composition: {e}\")))?;\n    model_hasher.update(comp_json.as_bytes());\n    let model_id = format!(\"{:x}\", model_hasher.finalize());\n\n    tracing::info!(id = %model_id, \"creating model\");\n    let model_resp = client\n        .post(format!(\"{api}/admin/models\"))\n        .header(\"Authorization\", &auth)\n        .json(&serde_json::json!({\n            \"id\": model_id,\n            \"metadata\": {\n                \"name\": model_name,\n                \"description\": config.description,\n                \"author\": model_author,\n                \"version\": model_version,\n                
\"netuid\": null,\n                \"weights_version\": null,\n                \"timeout\": model_timeout,\n                \"input_schema\": input_schema,\n                \"dsperse_version\": dsperse_version,\n                \"jstprove_version\": jstprove_version,\n            },\n            \"composition\": composition,\n        }))\n        .send()\n        .await\n        .map_err(|e| DsperseError::Other(format!(\"create model: {e}\")))?;\n\n    if !model_resp.status().is_success() {\n        let status = model_resp.status();\n        let text = model_resp.text().await.unwrap_or_default();\n        if !text.contains(\"already exists\") {\n            return Err(DsperseError::Other(format!(\n                \"create model failed ({status}): {text}\"\n            )));\n        }\n        tracing::info!(id = %model_id, \"model already exists\");\n    }\n\n    if config.activate {\n        tracing::info!(id = %model_id, \"activating model\");\n        let activate_resp = client\n            .patch(format!(\"{api}/admin/models/{model_id}\"))\n            .header(\"Authorization\", &auth)\n            .json(&serde_json::json!({ \"is_active\": true }))\n            .send()\n            .await\n            .map_err(|e| DsperseError::Other(format!(\"activate: {e}\")))?;\n\n        if !activate_resp.status().is_success() {\n            let status = activate_resp.status();\n            let text = activate_resp.text().await.unwrap_or_default();\n            return Err(DsperseError::Other(format!(\n                \"activate failed ({status}): {text}\"\n            )));\n        }\n    }\n\n    Ok(PublishResult {\n        model_id,\n        components_uploaded,\n        components_skipped,\n        weights_uploaded,\n        weights_skipped,\n    })\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/runner.rs",
    "content": "use std::collections::HashMap;\nuse std::path::{Path, PathBuf};\n\nuse ndarray::{ArrayD, IxDyn};\n\nuse jstprove_circuits::api::CircuitParamsType as CircuitParams;\n\nuse super::strategy::ExecutionStrategy;\nuse super::tensor_store::TensorStore;\nuse crate::backend::jstprove::JstproveBackend;\nuse crate::backend::onnx::NamedOutputs;\nuse crate::error::{DsperseError, Result};\nuse crate::schema::execution::{\n    ExecutionChain, ExecutionInfo, ExecutionMethod, ExecutionNode, ExecutionResultEntry,\n    RunMetadata,\n};\nuse crate::schema::metadata::{BackendKind, ModelMetadata, RunSliceMetadata};\nuse crate::slicer::onnx_proto::TensorProto;\nuse crate::utils::io::{\n    arrayd_to_value, build_msgpack_map, extract_input_data, map_get_ref, read_msgpack,\n    value_to_arrayd, write_msgpack,\n};\nuse crate::utils::paths::{find_metadata_path, resolve_relative_path, slice_dir_path};\nuse rmpv::Value;\n\npub struct RunConfig {\n    pub parallel: usize,\n    pub batch: bool,\n    pub weights_onnx: Option<PathBuf>,\n    pub combined: bool,\n}\n\nimpl Default for RunConfig {\n    fn default() -> Self {\n        Self {\n            parallel: 1,\n            batch: false,\n            weights_onnx: None,\n            combined: true,\n        }\n    }\n}\n\nfn resolve_circuit_path_required(\n    slices_dir: &Path,\n    circuit_path: Option<&str>,\n    label: &str,\n) -> Result<PathBuf> {\n    circuit_path\n        .map(|p| resolve_relative_path(slices_dir, p))\n        .transpose()?\n        .ok_or_else(|| DsperseError::Pipeline(format!(\"no circuit path for {label}\")))\n}\n\npub(crate) fn resolve_circuit_path_optional(\n    slices_dir: &Path,\n    circuit_path: Option<&str>,\n) -> Result<Option<PathBuf>> {\n    circuit_path\n        .map(|p| resolve_relative_path(slices_dir, p))\n        .transpose()\n}\n\npub fn load_model_metadata(slices_dir: &Path) -> Result<ModelMetadata> {\n    let meta_path = find_metadata_path(slices_dir).ok_or_else(|| {\n        
DsperseError::Metadata(format!(\n            \"no {} in slices\",\n            crate::utils::paths::METADATA_FILE\n        ))\n    })?;\n    let mut model_meta = ModelMetadata::load(&meta_path)?;\n\n    if model_meta.slices.is_empty() {\n        return Err(DsperseError::Metadata(format!(\n            \"{} has no slices in {}\",\n            crate::utils::paths::METADATA_FILE,\n            slices_dir.display()\n        )));\n    }\n\n    model_meta.slices.sort_by_key(|s| s.index);\n\n    Ok(model_meta)\n}\n\nfn validate_weights_onnx(\n    donor_init_map: &HashMap<String, &TensorProto>,\n    model_meta: &ModelMetadata,\n    slices_dir: &Path,\n) -> Result<()> {\n    for slice in &model_meta.slices {\n        let onnx_path = slice.resolve_onnx(slices_dir)?;\n        if !onnx_path.exists() {\n            return Err(DsperseError::Pipeline(format!(\n                \"slice_{} ONNX not found at {}\",\n                slice.index,\n                onnx_path.display()\n            )));\n        }\n        let slice_model = crate::slicer::onnx_proto::load_model(&onnx_path)?;\n        let slice_graph = slice_model.graph.as_ref().ok_or_else(|| {\n            DsperseError::Pipeline(format!(\n                \"slice_{} ONNX at {} has no graph\",\n                slice.index,\n                onnx_path.display()\n            ))\n        })?;\n        let context = format!(\"slice_{}\", slice.index);\n        crate::slicer::onnx_proto::validate_initializer_compatibility(\n            &slice_graph.initializer,\n            donor_init_map,\n            &context,\n        )?;\n    }\n    Ok(())\n}\n\nfn load_donor_model(\n    weights_onnx: Option<&PathBuf>,\n) -> Result<Option<crate::slicer::onnx_proto::ModelProto>> {\n    let weights_path = match weights_onnx {\n        Some(p) => p,\n        None => return Ok(None),\n    };\n    if !weights_path.is_file() {\n        return Err(DsperseError::Other(format!(\n            \"consumer weights ONNX not found: {}\",\n            
weights_path.display()\n        )));\n    }\n    Ok(Some(crate::slicer::onnx_proto::load_model(weights_path)?))\n}\n\nfn donor_init_map(\n    model: Option<&crate::slicer::onnx_proto::ModelProto>,\n) -> Result<Option<HashMap<String, &TensorProto>>> {\n    match model {\n        Some(m) => {\n            let graph = m.graph.as_ref().ok_or_else(|| {\n                DsperseError::Pipeline(\"consumer weights ONNX missing graph\".into())\n            })?;\n            Ok(Some(crate::slicer::onnx_proto::build_initializer_map(\n                graph,\n            )))\n        }\n        None => Ok(None),\n    }\n}\n\npub fn run_inference(\n    slices_dir: &Path,\n    input_path: &Path,\n    run_dir: &Path,\n    backend: &JstproveBackend,\n    config: &RunConfig,\n) -> Result<RunMetadata> {\n    let model_meta = load_model_metadata(slices_dir)?;\n\n    if config.combined\n        && model_meta.original_model_path.is_some()\n        && model_meta.traced_shapes.is_some()\n    {\n        return run_combined_inference(\n            slices_dir,\n            input_path,\n            run_dir,\n            backend,\n            config,\n            &model_meta,\n        );\n    } else if config.combined {\n        tracing::warn!(\n            \"combined mode requested but metadata missing original_model_path or traced_shapes, using per-slice execution\"\n        );\n    }\n\n    if model_meta.original_model_path.is_some() {\n        crate::slicer::materializer::ensure_all_slices_materialized(slices_dir, &model_meta)?;\n    }\n\n    let donor_model = load_donor_model(config.weights_onnx.as_ref())?;\n    let donor_map = donor_init_map(donor_model.as_ref())?;\n    if let Some(ref map) = donor_map {\n        validate_weights_onnx(map, &model_meta, slices_dir)?;\n        tracing::info!(\n            weights = %config.weights_onnx.as_ref().unwrap().display(),\n            \"validated consumer weights ONNX\"\n        );\n    }\n\n    std::fs::create_dir_all(run_dir).map_err(|e| 
DsperseError::io(e, run_dir))?;\n\n    let input_data = read_msgpack(input_path)?;\n\n    let chain = build_execution_chain(&model_meta, slices_dir)?;\n    let run_meta = build_run_metadata(&model_meta, slices_dir, &chain)?;\n\n    let mut tensor_cache = TensorStore::new();\n\n    let input_val = extract_input_data(&input_data).ok_or_else(|| {\n        DsperseError::Pipeline(\n            \"input has no recognized input key (input_data, input, data, inputs)\".into(),\n        )\n    })?;\n    let first_slice = model_meta\n        .slices\n        .first()\n        .ok_or_else(|| DsperseError::Pipeline(\"model has no slices\".into()))?;\n    let declared_inputs = &first_slice.dependencies.filtered_inputs;\n    if declared_inputs.is_empty() {\n        return Err(DsperseError::Pipeline(\n            \"first slice has no input dependency\".into(),\n        ));\n    }\n    if input_val.is_map() {\n        for name in declared_inputs {\n            let v = map_get_ref(input_val, name)\n                .ok_or_else(|| DsperseError::Pipeline(format!(\"input map missing key {name:?}\")))?;\n            tensor_cache.put(name.clone(), value_to_arrayd(v)?);\n        }\n    } else if declared_inputs.len() == 1 {\n        tensor_cache.put(declared_inputs[0].clone(), value_to_arrayd(input_val)?);\n    } else {\n        return Err(DsperseError::Pipeline(format!(\n            \"model declares {} inputs but input is not a map\",\n            declared_inputs.len()\n        )));\n    }\n\n    let input_copy = run_dir.join(crate::utils::paths::INPUT_FILE);\n    write_msgpack(&input_copy, &input_data)?;\n\n    let mut results: Vec<ExecutionResultEntry> = Vec::new();\n\n    let mut current = chain.head.clone();\n    while let Some(slice_id) = current.take() {\n        let node = chain\n            .nodes\n            .get(&slice_id)\n            .ok_or_else(|| DsperseError::Pipeline(format!(\"missing node {slice_id}\")))?;\n\n        let slice_meta = 
run_meta.slices.get(&slice_id).ok_or_else(|| {\n            DsperseError::Pipeline(format!(\"missing run slice metadata {slice_id}\"))\n        })?;\n\n        let slice_run_dir = run_dir.join(&slice_id);\n        std::fs::create_dir_all(&slice_run_dir).map_err(|e| DsperseError::io(e, &slice_run_dir))?;\n\n        tracing::info!(slice = %slice_id, circuit = node.use_circuit, \"executing\");\n\n        let exec_result = execute_slice(\n            slices_dir,\n            &slice_run_dir,\n            &slice_id,\n            node,\n            slice_meta,\n            &mut tensor_cache,\n            backend,\n            config,\n            donor_map.as_ref(),\n        );\n\n        let exec_info = match exec_result {\n            Ok(info) => info,\n            Err(e) => {\n                tracing::error!(slice = %slice_id, error = %e, \"execution failed\");\n                let method = ExecutionStrategy::from_metadata(slice_meta, node.use_circuit)\n                    .map(|s| s.execution_method())\n                    .unwrap_or(ExecutionMethod::OnnxOnly);\n                results.push(ExecutionResultEntry {\n                    slice_id: slice_id.clone(),\n                    witness_execution: Some(ExecutionInfo {\n                        method,\n                        success: false,\n                        error: Some(e.to_string()),\n                        witness_file: None,\n                        tile_exec_infos: Vec::new(),\n                    }),\n                    proof_execution: None,\n                    verification_execution: None,\n                });\n                break;\n            }\n        };\n\n        results.push(ExecutionResultEntry {\n            slice_id: slice_id.clone(),\n            witness_execution: Some(exec_info),\n            proof_execution: None,\n            verification_execution: None,\n        });\n\n        current = node.next.clone();\n    }\n\n    let mut final_meta = run_meta;\n    
final_meta.execution_chain.execution_results = results;\n    final_meta.run_directory = Some(run_dir.to_string_lossy().into_owned());\n\n    let meta_out = run_dir.join(crate::utils::paths::METADATA_FILE);\n    crate::utils::metadata::save_run_metadata(&meta_out, &final_meta)?;\n\n    let last_slice = model_meta\n        .slices\n        .last()\n        .ok_or_else(|| DsperseError::Pipeline(\"model has no slices\".into()))?;\n    let last_slice_id = format!(\"slice_{}\", last_slice.index);\n    if let Some(failed) = final_meta\n        .execution_chain\n        .execution_results\n        .iter()\n        .find(|r| r.witness_execution.as_ref().is_some_and(|w| !w.success))\n    {\n        let err_msg = failed\n            .witness_execution\n            .as_ref()\n            .and_then(|w| w.error.as_deref())\n            .unwrap_or(\"unknown\");\n        return Err(DsperseError::Pipeline(format!(\n            \"pipeline failed at {}: {err_msg}\",\n            failed.slice_id\n        )));\n    }\n\n    let slice_run_meta = final_meta.slices.get(&last_slice_id);\n    let last_strategy = match slice_run_meta {\n        Some(m) => {\n            let use_circuit = final_meta\n                .execution_chain\n                .nodes\n                .get(&last_slice_id)\n                .is_some_and(|n| n.use_circuit);\n            ExecutionStrategy::from_metadata(m, use_circuit).ok()\n        }\n        None => None,\n    };\n    let output_arrs: Vec<&ArrayD<f64>> = {\n        let strategy_output = last_strategy\n            .as_ref()\n            .and_then(|s| s.output_name())\n            .and_then(|name| tensor_cache.try_get(name));\n        if let Some(arr) = strategy_output {\n            vec![arr]\n        } else if !model_meta.output_names.is_empty() {\n            let found: Vec<_> = model_meta\n                .output_names\n                .iter()\n                .filter_map(|n| tensor_cache.try_get(n))\n                .collect();\n            if 
found.is_empty() {\n                tracing::warn!(\n                    expected = ?model_meta.output_names,\n                    available = ?tensor_cache.keys().collect::<Vec<_>>(),\n                    \"none of the declared output_names found in tensor cache\"\n                );\n            }\n            found\n        } else {\n            last_slice\n                .dependencies\n                .output\n                .iter()\n                .find_map(|n| tensor_cache.try_get(n))\n                .into_iter()\n                .collect()\n        }\n    };\n    if output_arrs.is_empty() {\n        let first_error = final_meta\n            .execution_chain\n            .execution_results\n            .iter()\n            .filter_map(|r| {\n                r.witness_execution\n                    .as_ref()\n                    .and_then(|w| w.error.as_deref())\n                    .map(|err| format!(\"{}: {err}\", r.slice_id))\n            })\n            .next();\n        return Err(match first_error {\n            Some(err) => DsperseError::Pipeline(format!(\"pipeline failed at {err}\")),\n            None => DsperseError::Pipeline(format!(\n                \"no output tensor found for last slice {last_slice_id}\"\n            )),\n        });\n    }\n    let output_path = run_dir.join(crate::utils::paths::OUTPUT_FILE);\n    let output_val = Value::Array(output_arrs.iter().map(|arr| arrayd_to_value(arr)).collect());\n    write_msgpack(\n        &output_path,\n        &build_msgpack_map(vec![(\"output_data\", output_val)]),\n    )?;\n\n    Ok(final_meta)\n}\n\nfn run_combined_inference(\n    slices_dir: &Path,\n    input_path: &Path,\n    run_dir: &Path,\n    backend: &JstproveBackend,\n    config: &RunConfig,\n    model_meta: &ModelMetadata,\n) -> Result<RunMetadata> {\n    let combined_path =\n        crate::slicer::combiner::ensure_combined_materialized(slices_dir, model_meta)?;\n\n    let donor_model = 
load_donor_model(config.weights_onnx.as_ref())?;\n    let donor_map = donor_init_map(donor_model.as_ref())?;\n    if let Some(ref map) = donor_map {\n        let combined_model = crate::slicer::onnx_proto::load_model(&combined_path)?;\n        let combined_graph = combined_model\n            .graph\n            .as_ref()\n            .ok_or_else(|| DsperseError::Pipeline(\"combined ONNX missing graph\".into()))?;\n        crate::slicer::onnx_proto::validate_initializer_compatibility(\n            &combined_graph.initializer,\n            map,\n            \"combined\",\n        )?;\n        tracing::info!(\n            weights = %config.weights_onnx.as_ref().unwrap().display(),\n            \"validated consumer weights against combined ONNX\"\n        );\n    }\n\n    std::fs::create_dir_all(run_dir).map_err(|e| DsperseError::io(e, run_dir))?;\n\n    let input_data = read_msgpack(input_path)?;\n    let input_val = extract_input_data(&input_data).ok_or_else(|| {\n        DsperseError::Pipeline(\n            \"input has no recognized input key (input_data, input, data, inputs)\".into(),\n        )\n    })?;\n    let first_slice = model_meta\n        .slices\n        .first()\n        .ok_or_else(|| DsperseError::Pipeline(\"model has no slices\".into()))?;\n    let declared_inputs = &first_slice.dependencies.filtered_inputs;\n    if declared_inputs.is_empty() {\n        return Err(DsperseError::Pipeline(\n            \"first slice has no input dependency\".into(),\n        ));\n    }\n\n    let input_copy = run_dir.join(crate::utils::paths::INPUT_FILE);\n    write_msgpack(&input_copy, &input_data)?;\n\n    let effective_combined = if let Some(ref map) = donor_map {\n        Some(crate::slicer::onnx_proto::build_patched_onnx(\n            &combined_path,\n            map,\n        )?)\n    } else {\n        None\n    };\n    let effective_path = effective_combined\n        .as_ref()\n        .map_or(combined_path.as_path(), |t| t.path());\n\n    let named_outputs = if 
input_val.is_map() {\n        let mut cache = TensorStore::new();\n        for name in declared_inputs {\n            let v = map_get_ref(input_val, name)\n                .ok_or_else(|| DsperseError::Pipeline(format!(\"input map missing key {name:?}\")))?;\n            cache.put(name.clone(), value_to_arrayd(v)?);\n        }\n        let inputs: Vec<String> = declared_inputs.clone();\n        run_onnx_inference_multi_named(effective_path, &cache, &inputs)?\n    } else if declared_inputs.len() == 1 {\n        let input_arr = value_to_arrayd(input_val)?;\n        run_onnx_inference_named(effective_path, &input_arr)?\n    } else {\n        return Err(DsperseError::Pipeline(format!(\n            \"model declares {} inputs but input is not a map\",\n            declared_inputs.len()\n        )));\n    };\n\n    tracing::info!(\n        outputs = named_outputs.len(),\n        \"combined model inference complete\"\n    );\n\n    let mut tensor_cache = TensorStore::new();\n    for (name, (data, shape)) in &named_outputs {\n        let arr = ArrayD::from_shape_vec(IxDyn(shape), data.clone())\n            .map_err(|e| DsperseError::Pipeline(format!(\"output reshape '{name}': {e}\")))?;\n        tensor_cache.put(name.clone(), arr);\n    }\n\n    for name in declared_inputs {\n        if !tensor_cache.contains(name) {\n            if input_val.is_map() {\n                let v = map_get_ref(input_val, name).ok_or_else(|| {\n                    DsperseError::Pipeline(format!(\n                        \"combined fallback: input map missing key {name:?}\"\n                    ))\n                })?;\n                tensor_cache.put(name.clone(), value_to_arrayd(v)?);\n            } else if declared_inputs.len() == 1 {\n                tensor_cache.put(name.clone(), value_to_arrayd(input_val)?);\n            }\n        }\n    }\n\n    crate::slicer::materializer::ensure_all_slices_materialized(slices_dir, model_meta)?;\n    let chain = build_execution_chain(model_meta, 
slices_dir)?;\n    let run_meta = build_run_metadata(model_meta, slices_dir, &chain)?;\n\n    let mut results: Vec<ExecutionResultEntry> = Vec::new();\n\n    for slice in &model_meta.slices {\n        let slice_id = format!(\"slice_{}\", slice.index);\n        let node = chain\n            .nodes\n            .get(&slice_id)\n            .ok_or_else(|| DsperseError::Pipeline(format!(\"missing node {slice_id}\")))?;\n\n        let slice_meta = run_meta.slices.get(&slice_id).ok_or_else(|| {\n            DsperseError::Pipeline(format!(\"missing run slice metadata {slice_id}\"))\n        })?;\n\n        let slice_run_dir = run_dir.join(&slice_id);\n        std::fs::create_dir_all(&slice_run_dir).map_err(|e| DsperseError::io(e, &slice_run_dir))?;\n\n        if !node.use_circuit {\n            results.push(ExecutionResultEntry {\n                slice_id: slice_id.clone(),\n                witness_execution: Some(ExecutionInfo {\n                    method: ExecutionMethod::OnnxOnly,\n                    success: true,\n                    error: None,\n                    witness_file: None,\n                    tile_exec_infos: Vec::new(),\n                }),\n                proof_execution: None,\n                verification_execution: None,\n            });\n            continue;\n        }\n\n        let strategy = ExecutionStrategy::from_metadata(slice_meta, node.use_circuit)?;\n\n        if let ExecutionStrategy::ChannelSplit(_) = &strategy {\n            return Err(DsperseError::Pipeline(format!(\n                \"{slice_id}: combined mode does not support channel-split circuit slices; use --combined false\"\n            )));\n        }\n\n        if let ExecutionStrategy::DimSplit(_) = &strategy {\n            return Err(DsperseError::Pipeline(format!(\n                \"{slice_id}: combined mode does not support dim-split circuit slices; use --combined false\"\n            )));\n        }\n\n        if let ExecutionStrategy::Tiled(tiling) = &strategy {\n    
        let result = super::tiled::execute_combined_tiled(\n                slices_dir,\n                &slice_run_dir,\n                &slice_id,\n                tiling,\n                slice_meta.jstprove_circuit_path.as_deref(),\n                &tensor_cache,\n                backend,\n                config,\n                donor_map.as_ref(),\n            )?;\n            for (name, tensor) in result.outputs {\n                tensor_cache.put(name, tensor);\n            }\n\n            let success = result.info.success;\n            results.push(ExecutionResultEntry {\n                slice_id: slice_id.clone(),\n                witness_execution: Some(result.info),\n                proof_execution: None,\n                verification_execution: None,\n            });\n\n            if !success {\n                break;\n            }\n            continue;\n        }\n\n        let circuit_path = resolve_circuit_path_required(\n            slices_dir,\n            slice_meta.jstprove_circuit_path.as_deref(),\n            &slice_id,\n        )?;\n\n        let params = backend.load_params(&circuit_path)?;\n        let is_wai = params.as_ref().is_some_and(|p| p.weights_as_inputs);\n\n        if donor_map.is_some() && !is_wai {\n            return Err(DsperseError::Pipeline(format!(\n                \"{slice_id}: consumer weights require circuits compiled with --weights-as-inputs\"\n            )));\n        }\n\n        let activation_inputs: Vec<String> = slice\n            .dependencies\n            .filtered_inputs\n            .iter()\n            .filter(|s| !s.is_empty())\n            .cloned()\n            .collect();\n\n        let witness_result = if activation_inputs.is_empty() {\n            Err(DsperseError::Pipeline(format!(\n                \"{slice_id}: no activation inputs declared for circuit slice\"\n            )))\n        } else {\n            let mut flat_activations: Vec<f64> = Vec::new();\n            for input_name in 
&activation_inputs {\n                let input_arr = tensor_cache.get(input_name).map_err(|_| {\n                    DsperseError::Pipeline(format!(\n                        \"{slice_id}: activation input '{input_name}' not found in combined model outputs\"\n                    ))\n                })?;\n                flat_activations.extend(input_arr.iter());\n            }\n\n            if is_wai {\n                let onnx_path = slice.resolve_onnx(slices_dir)?;\n                let initializers = if let Some(donor) = donor_map.as_ref() {\n                    let slice_model = crate::slicer::onnx_proto::load_model(&onnx_path)?;\n                    let slice_graph = slice_model.graph.as_ref().ok_or_else(|| {\n                        DsperseError::Pipeline(format!(\"{slice_id}: ONNX missing graph\"))\n                    })?;\n                    let mut merged = crate::slicer::onnx_proto::build_initializer_map(slice_graph);\n                    for (k, v) in donor.iter() {\n                        merged.insert(k.clone(), *v);\n                    }\n                    extract_initializers_from_map(&merged, params.as_ref().unwrap())?\n                } else {\n                    extract_onnx_initializers(&onnx_path, params.as_ref().unwrap())?\n                };\n                backend.witness_f64(&circuit_path, &flat_activations, &initializers)\n            } else {\n                backend.witness_f64(&circuit_path, &flat_activations, &[])\n            }\n        };\n\n        match witness_result {\n            Ok(witness_bytes) => {\n                let witness_path = slice_run_dir.join(crate::utils::paths::WITNESS_FILE);\n                std::fs::write(&witness_path, &witness_bytes)\n                    .map_err(|e| DsperseError::io(e, &witness_path))?;\n\n                tracing::info!(slice = %slice_id, \"witness generated from combined outputs\");\n\n                results.push(ExecutionResultEntry {\n                    slice_id: 
slice_id.clone(),\n                    witness_execution: Some(ExecutionInfo {\n                        method: ExecutionMethod::JstproveGenWitness,\n                        success: true,\n                        error: None,\n                        witness_file: Some(witness_path.to_string_lossy().into_owned()),\n                        tile_exec_infos: Vec::new(),\n                    }),\n                    proof_execution: None,\n                    verification_execution: None,\n                });\n            }\n            Err(e) => {\n                tracing::error!(slice = %slice_id, error = %e, \"witness generation failed\");\n                results.push(ExecutionResultEntry {\n                    slice_id: slice_id.clone(),\n                    witness_execution: Some(ExecutionInfo {\n                        method: ExecutionMethod::JstproveGenWitness,\n                        success: false,\n                        error: Some(e.to_string()),\n                        witness_file: None,\n                        tile_exec_infos: Vec::new(),\n                    }),\n                    proof_execution: None,\n                    verification_execution: None,\n                });\n                break;\n            }\n        }\n    }\n\n    let mut final_meta = run_meta;\n    final_meta.execution_chain.execution_results = results;\n    final_meta.run_directory = Some(run_dir.to_string_lossy().into_owned());\n\n    let witness_failure = final_meta\n        .execution_chain\n        .execution_results\n        .iter()\n        .filter_map(|r| {\n            r.witness_execution\n                .as_ref()\n                .filter(|w| !w.success)\n                .and_then(|w| w.error.as_ref())\n                .map(|err| format!(\"{}: {err}\", r.slice_id))\n        })\n        .next();\n    if let Some(err) = witness_failure {\n        let meta_out = run_dir.join(crate::utils::paths::METADATA_FILE);\n        let _ = 
crate::utils::metadata::save_run_metadata(&meta_out, &final_meta);\n        return Err(DsperseError::Pipeline(format!(\n            \"combined pipeline failed at {err}\"\n        )));\n    }\n\n    let meta_out = run_dir.join(crate::utils::paths::METADATA_FILE);\n    crate::utils::metadata::save_run_metadata(&meta_out, &final_meta)?;\n\n    let last_slice = model_meta\n        .slices\n        .last()\n        .ok_or_else(|| DsperseError::Pipeline(\"model has no slices\".into()))?;\n    let output_arrs: Vec<&ArrayD<f64>> = if !model_meta.output_names.is_empty() {\n        model_meta\n            .output_names\n            .iter()\n            .filter_map(|n| tensor_cache.try_get(n))\n            .collect()\n    } else {\n        last_slice\n            .dependencies\n            .output\n            .iter()\n            .find_map(|n| tensor_cache.try_get(n))\n            .into_iter()\n            .collect()\n    };\n\n    if output_arrs.is_empty() {\n        let expected: Vec<&str> = if !model_meta.output_names.is_empty() {\n            model_meta.output_names.iter().map(String::as_str).collect()\n        } else {\n            last_slice\n                .dependencies\n                .output\n                .iter()\n                .map(String::as_str)\n                .collect()\n        };\n        let available: Vec<&String> = tensor_cache.keys().collect();\n        return Err(DsperseError::Pipeline(format!(\n            \"no output tensor found in combined model outputs; expected {expected:?}, available {available:?}\"\n        )));\n    }\n\n    let output_path = run_dir.join(crate::utils::paths::OUTPUT_FILE);\n    let output_val = Value::Array(output_arrs.iter().map(|arr| arrayd_to_value(arr)).collect());\n    write_msgpack(\n        &output_path,\n        &build_msgpack_map(vec![(\"output_data\", output_val)]),\n    )?;\n\n    tracing::info!(\n        run_dir = %run_dir.display(),\n        slices = model_meta.slices.len(),\n        \"combined inference 
complete\"\n    );\n\n    Ok(final_meta)\n}\n\n#[allow(clippy::too_many_arguments)]\nfn execute_slice(\n    slices_dir: &Path,\n    slice_run_dir: &Path,\n    slice_id: &str,\n    node: &ExecutionNode,\n    meta: &RunSliceMetadata,\n    tensor_cache: &mut TensorStore,\n    backend: &JstproveBackend,\n    config: &RunConfig,\n    donor_init_map: Option<&HashMap<String, &TensorProto>>,\n) -> Result<ExecutionInfo> {\n    let strategy = ExecutionStrategy::from_metadata(meta, node.use_circuit)?;\n    match strategy {\n        ExecutionStrategy::ChannelSplit(cs) => {\n            let target_shape = meta\n                .dependencies\n                .output\n                .iter()\n                .position(|name| name == &cs.output_name)\n                .and_then(|idx| meta.output_shape.get(idx))\n                .map(|v| v.as_slice());\n            if target_shape.is_none() {\n                tracing::debug!(\n                    slice = %slice_id,\n                    output_name = %cs.output_name,\n                    \"target_shape lookup failed; output will not be reshaped\"\n                );\n            }\n            let result = super::channel_split::execute_channel_split(\n                slices_dir,\n                slice_run_dir,\n                slice_id,\n                cs,\n                target_shape,\n                tensor_cache,\n                backend,\n                donor_init_map,\n            )?;\n            for (name, tensor) in result.outputs {\n                tensor_cache.put(name, tensor);\n            }\n            Ok(result.info)\n        }\n        ExecutionStrategy::Tiled(tiling) => {\n            let slice_circuit =\n                resolve_circuit_path_optional(slices_dir, meta.jstprove_circuit_path.as_deref())?;\n            let result = super::tiled::execute_tiled(\n                slices_dir,\n                slice_run_dir,\n                slice_id,\n                tiling,\n                slice_circuit.as_deref(),\n    
            tensor_cache,\n                backend,\n                config,\n                donor_init_map,\n            )?;\n            for (name, tensor) in result.outputs {\n                tensor_cache.put(name, tensor);\n            }\n            Ok(result.info)\n        }\n        ExecutionStrategy::DimSplit(ds) => {\n            let target_shape = meta\n                .dependencies\n                .output\n                .iter()\n                .position(|name| name == &ds.output_name)\n                .and_then(|idx| meta.output_shape.get(idx))\n                .map(|v| v.as_slice());\n            let result = super::dim_split::execute_dim_split(\n                slices_dir,\n                slice_run_dir,\n                slice_id,\n                ds,\n                target_shape,\n                tensor_cache,\n                backend,\n                donor_init_map,\n            )?;\n            for (name, tensor) in result.outputs {\n                tensor_cache.put(name, tensor);\n            }\n            Ok(result.info)\n        }\n        ExecutionStrategy::Single { .. 
} => {\n            let result = execute_single(\n                slices_dir,\n                slice_run_dir,\n                slice_id,\n                node,\n                meta,\n                tensor_cache,\n                backend,\n                donor_init_map,\n            )?;\n            for (name, tensor) in result.outputs {\n                tensor_cache.put(name, tensor);\n            }\n            Ok(result.info)\n        }\n    }\n}\n\n#[allow(clippy::too_many_arguments)]\nfn execute_single(\n    slices_dir: &Path,\n    slice_run_dir: &Path,\n    slice_id: &str,\n    node: &ExecutionNode,\n    meta: &RunSliceMetadata,\n    tensor_cache: &TensorStore,\n    backend: &JstproveBackend,\n    donor_init_map: Option<&HashMap<String, &TensorProto>>,\n) -> Result<crate::schema::execution::StrategyOutput> {\n    let inputs: Vec<String> = meta\n        .dependencies\n        .filtered_inputs\n        .iter()\n        .filter(|s| !s.is_empty())\n        .cloned()\n        .collect();\n    let multi_input = inputs.len() > 1;\n\n    if inputs.is_empty() {\n        return Err(DsperseError::Pipeline(format!(\n            \"{slice_id}: no activation inputs declared\"\n        )));\n    }\n\n    let onnx_path = PathBuf::from(&meta.path);\n\n    let patched_onnx = if let Some(map) = donor_init_map {\n        Some(crate::slicer::onnx_proto::build_patched_onnx(\n            &onnx_path, map,\n        )?)\n    } else {\n        None\n    };\n    let effective_onnx: &Path = patched_onnx\n        .as_ref()\n        .map_or(onnx_path.as_path(), |t| t.path());\n\n    if node.use_circuit {\n        let circuit_path = resolve_circuit_path_required(\n            slices_dir,\n            meta.jstprove_circuit_path.as_deref(),\n            slice_id,\n        )?;\n\n        let params = backend.load_params(&circuit_path)?;\n        let is_wai = params.as_ref().is_some_and(|p| p.weights_as_inputs);\n\n        if donor_init_map.is_some() && !is_wai {\n            return 
Err(DsperseError::Pipeline(format!(\n                \"{slice_id}: consumer weights require circuits compiled with --weights-as-inputs\"\n            )));\n        }\n\n        let named = if multi_input {\n            run_onnx_inference_multi_named(effective_onnx, tensor_cache, &inputs)?\n        } else {\n            let input_tensor = tensor_cache.gather(&inputs[..1])?;\n            run_onnx_inference_named(effective_onnx, &input_tensor)?\n        };\n\n        let outputs = collect_named_outputs(&meta.dependencies.output, named)?;\n\n        let flat_activations = flatten_cached_inputs(tensor_cache, &inputs)?;\n        let witness_bytes = if is_wai {\n            generate_wai_witness(\n                backend,\n                &circuit_path,\n                &onnx_path,\n                donor_init_map,\n                params.as_ref().unwrap(),\n                &flat_activations,\n            )?\n        } else {\n            backend.witness_f64(&circuit_path, &flat_activations, &[])?\n        };\n\n        let witness_path = slice_run_dir.join(crate::utils::paths::WITNESS_FILE);\n        std::fs::write(&witness_path, &witness_bytes)\n            .map_err(|e| DsperseError::io(e, &witness_path))?;\n\n        Ok(crate::schema::execution::StrategyOutput {\n            info: ExecutionInfo {\n                method: ExecutionMethod::JstproveGenWitness,\n                success: true,\n                error: None,\n                witness_file: Some(witness_path.to_string_lossy().into_owned()),\n                tile_exec_infos: Vec::new(),\n            },\n            outputs,\n        })\n    } else {\n        let named = if multi_input {\n            run_onnx_inference_multi_named(effective_onnx, tensor_cache, &inputs)?\n        } else {\n            let input_tensor = tensor_cache.gather(&inputs)?;\n            run_onnx_inference_named(effective_onnx, &input_tensor)?\n        };\n        let outputs = collect_named_outputs(&meta.dependencies.output, named)?;\n\n   
     Ok(crate::schema::execution::StrategyOutput {\n            info: ExecutionInfo {\n                method: ExecutionMethod::OnnxOnly,\n                success: true,\n                error: None,\n                witness_file: None,\n                tile_exec_infos: Vec::new(),\n            },\n            outputs,\n        })\n    }\n}\n\n#[cfg(test)]\nfn store_named_outputs(\n    tensor_cache: &mut TensorStore,\n    output_names: &[String],\n    named_outputs: HashMap<String, (Vec<f64>, Vec<usize>)>,\n) -> Result<()> {\n    for (name, tensor) in collect_named_outputs(output_names, named_outputs)? {\n        tensor_cache.put(name, tensor);\n    }\n    Ok(())\n}\n\nfn collect_named_outputs(\n    output_names: &[String],\n    mut named_outputs: HashMap<String, (Vec<f64>, Vec<usize>)>,\n) -> Result<Vec<(String, ArrayD<f64>)>> {\n    let mut seen = std::collections::HashSet::new();\n    let mut result = Vec::new();\n    for name in output_names {\n        if !seen.insert(name) {\n            return Err(DsperseError::Pipeline(format!(\n                \"duplicate declared output '{name}'\"\n            )));\n        }\n        let (data, shape) = named_outputs\n            .remove(name)\n            .ok_or_else(|| DsperseError::Pipeline(format!(\"missing declared output '{name}'\")))?;\n        let arr = ArrayD::from_shape_vec(IxDyn(&shape), data)\n            .map_err(|e| DsperseError::Pipeline(format!(\"output reshape '{name}': {e}\")))?;\n        result.push((name.clone(), arr));\n    }\n    Ok(result)\n}\n\npub(crate) fn run_onnx_inference(onnx_path: &Path, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {\n    let input_flat: Vec<f64> = input.iter().copied().collect();\n    let input_shape = input.shape();\n    let (output_data, output_shape) =\n        crate::backend::onnx::run_inference(onnx_path, &input_flat, input_shape)?;\n\n    ArrayD::from_shape_vec(IxDyn(&output_shape), output_data)\n        .map_err(|e| DsperseError::Pipeline(format!(\"output reshape: 
{e}\")))\n}\n\npub(crate) fn run_onnx_inference_named(\n    onnx_path: &Path,\n    input: &ArrayD<f64>,\n) -> Result<NamedOutputs> {\n    let input_flat: Vec<f64> = input.iter().copied().collect();\n    let input_shape = input.shape();\n    crate::backend::onnx::run_inference_named(onnx_path, &input_flat, input_shape)\n}\n\npub(crate) fn run_onnx_inference_multi_named(\n    onnx_path: &Path,\n    tensor_cache: &TensorStore,\n    input_names: &[String],\n) -> Result<NamedOutputs> {\n    let inputs: Vec<(&str, Vec<f64>, Vec<usize>)> = input_names\n        .iter()\n        .map(|name| {\n            let arr = tensor_cache.get(name)?;\n            Ok((\n                name.as_str(),\n                arr.iter().copied().collect(),\n                arr.shape().to_vec(),\n            ))\n        })\n        .collect::<Result<Vec<_>>>()?;\n    crate::backend::onnx::run_inference_multi_named(onnx_path, &inputs)\n}\n\npub(crate) fn build_execution_chain(\n    model_meta: &ModelMetadata,\n    slices_dir: &Path,\n) -> Result<ExecutionChain> {\n    let mut nodes = HashMap::new();\n    let mut head = None;\n\n    for (i, slice) in model_meta.slices.iter().enumerate() {\n        let slice_id = format!(\"slice_{}\", slice.index);\n        let slice_dir = slice_dir_path(slices_dir, slice.index);\n\n        if i == 0 {\n            head = Some(slice_id.clone());\n        }\n\n        let bundle = slice_dir.join(\"jstprove/circuit.bundle\");\n        let (has_circuit, circuit_path) = if bundle.is_dir() {\n            let rel = format!(\"slice_{}/jstprove/circuit.bundle\", slice.index);\n            (true, Some(rel))\n        } else {\n            (false, None)\n        };\n        let next = model_meta\n            .slices\n            .get(i + 1)\n            .map(|s| format!(\"slice_{}\", s.index));\n\n        let onnx_path = Some(\n            slice\n                .resolve_onnx(slices_dir)?\n                .to_string_lossy()\n                .into_owned(),\n        );\n\n      
  let backend = if has_circuit {\n            BackendKind::Jstprove\n        } else {\n            BackendKind::Onnx\n        };\n\n        nodes.insert(\n            slice_id.clone(),\n            ExecutionNode {\n                slice_id: slice_id.clone(),\n                primary: Some(backend.to_string()),\n                fallbacks: if has_circuit {\n                    vec![\"onnx\".into()]\n                } else {\n                    Vec::new()\n                },\n                use_circuit: has_circuit,\n                next,\n                circuit_path,\n                onnx_path,\n                backend,\n            },\n        );\n    }\n\n    Ok(ExecutionChain {\n        head,\n        nodes,\n        fallback_map: HashMap::new(),\n        execution_results: Vec::new(),\n        jstprove_proved_slices: 0,\n        jstprove_verified_slices: 0,\n    })\n}\n\npub(crate) fn build_run_metadata(\n    model_meta: &ModelMetadata,\n    slices_dir: &Path,\n    chain: &ExecutionChain,\n) -> Result<RunMetadata> {\n    let mut slices = HashMap::new();\n\n    for slice in &model_meta.slices {\n        let slice_id = format!(\"slice_{}\", slice.index);\n        let node = chain.nodes.get(&slice_id);\n        let has_circuit = node.is_some_and(|n| n.use_circuit);\n\n        let run_slice = RunSliceMetadata {\n            path: slice\n                .resolve_onnx(slices_dir)?\n                .to_string_lossy()\n                .into_owned(),\n            input_shape: slice.shape.tensor_shape.input.clone(),\n            output_shape: slice.shape.tensor_shape.output.clone(),\n            dependencies: slice.dependencies.clone(),\n            tiling: slice.tiling.clone(),\n            channel_split: slice.channel_split.clone(),\n            dim_split: slice.dim_split.clone(),\n            backend: if has_circuit {\n                BackendKind::Jstprove\n            } else {\n                BackendKind::Onnx\n            },\n            jstprove_circuit_path: 
node.and_then(|n| n.circuit_path.clone()),\n            jstprove_settings_path: None,\n        };\n\n        slices.insert(slice_id, run_slice);\n    }\n\n    Ok(RunMetadata {\n        slices,\n        execution_chain: chain.clone(),\n        packaging_type: None,\n        source_path: Some(slices_dir.to_string_lossy().into_owned()),\n        run_directory: None,\n        model_path: None,\n    })\n}\n\npub(crate) fn extract_initializers_from_map(\n    init_map: &HashMap<String, &TensorProto>,\n    params: &CircuitParams,\n) -> Result<Vec<(Vec<f64>, Vec<usize>)>> {\n    let mut initializers = Vec::new();\n    for io in &params.inputs {\n        if let Some(tensor) = init_map.get(&io.name) {\n            let f32_vals = crate::slicer::onnx_proto::tensor_to_f32(tensor);\n            let mut f64_vals: Vec<f64> = f32_vals.iter().map(|&v| f64::from(v)).collect();\n            let target_shape = &io.shape;\n            let tensor_shape: Vec<usize> = tensor.dims.iter().map(|&d| d as usize).collect();\n            let target_elems: usize = target_shape.iter().product();\n            if f64_vals.len() < target_elems && !target_shape.is_empty() && !tensor_shape.is_empty()\n            {\n                let is_bias = tensor_shape.len() == 1;\n                let pad_val: f64 = if is_bias { -10.0 } else { 0.0 };\n                let last = target_shape.len() - 1;\n                let target_last = target_shape[last];\n                let donor_last = tensor_shape[last];\n                if donor_last < target_last {\n                    let rows = f64_vals.len() / donor_last.max(1);\n                    let mut padded = Vec::with_capacity(target_elems);\n                    for row in 0..rows {\n                        let start = row * donor_last;\n                        let end = start + donor_last;\n                        padded.extend_from_slice(&f64_vals[start..end.min(f64_vals.len())]);\n                        padded.resize(padded.len() + (target_last - donor_last), 
pad_val);\n                    }\n                    f64_vals = padded;\n                }\n            }\n            let shape: Vec<usize> = if f64_vals.len() == target_elems {\n                target_shape.clone()\n            } else {\n                tensor_shape\n            };\n            initializers.push((f64_vals, shape));\n        }\n    }\n    Ok(initializers)\n}\n\npub fn extract_onnx_initializers(\n    onnx_path: &Path,\n    params: &CircuitParams,\n) -> Result<Vec<(Vec<f64>, Vec<usize>)>> {\n    let model = crate::slicer::onnx_proto::load_model(onnx_path)?;\n    let graph = model\n        .graph\n        .as_ref()\n        .ok_or_else(|| DsperseError::Pipeline(\"ONNX model missing graph\".into()))?;\n    let init_map = crate::slicer::onnx_proto::build_initializer_map(graph);\n    extract_initializers_from_map(&init_map, params)\n}\n\npub(crate) fn flatten_cached_inputs(cache: &TensorStore, names: &[String]) -> Result<Vec<f64>> {\n    let arrays: Vec<&ArrayD<f64>> = names.iter().map(|n| cache.get(n)).collect::<Result<_>>()?;\n    let total: usize = arrays.iter().map(|a| a.len()).sum();\n    let mut flat = Vec::with_capacity(total);\n    for arr in arrays {\n        flat.extend(arr.iter());\n    }\n    Ok(flat)\n}\n\npub(crate) fn generate_wai_witness(\n    backend: &JstproveBackend,\n    circuit_path: &Path,\n    slice_onnx_path: &Path,\n    donor_init_map: Option<&HashMap<String, &TensorProto>>,\n    params: &CircuitParams,\n    flat_activations: &[f64],\n) -> Result<Vec<u8>> {\n    let initializers = if let Some(donor) = donor_init_map {\n        let slice_model = crate::slicer::onnx_proto::load_model(slice_onnx_path)?;\n        let slice_graph = slice_model\n            .graph\n            .as_ref()\n            .ok_or_else(|| DsperseError::Pipeline(\"slice ONNX missing graph\".into()))?;\n        let mut merged = crate::slicer::onnx_proto::build_initializer_map(slice_graph);\n        for (k, v) in donor.iter() {\n            
merged.insert(k.clone(), *v);\n        }\n        extract_initializers_from_map(&merged, params)?\n    } else {\n        extract_onnx_initializers(slice_onnx_path, params)?\n    };\n    backend.witness_f64(circuit_path, flat_activations, &initializers)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::super::tiled::{reconstruct_from_tiles, reshape_to_4d, split_into_tiles};\n    use super::*;\n    use crate::schema::tiling::TilingInfo;\n    use ndarray::Array4;\n\n    fn make_tiling(\n        tile_size: usize,\n        tiles_y: usize,\n        tiles_x: usize,\n        halo: [i64; 4],\n        out_tile: [i64; 2],\n        c_out: usize,\n    ) -> TilingInfo {\n        TilingInfo {\n            slice_idx: 0,\n            tile_size,\n            num_tiles: tiles_y * tiles_x,\n            tiles_y,\n            tiles_x,\n            halo,\n            out_tile,\n            stride: [1, 1],\n            c_in: 1,\n            c_out,\n            input_name: \"input\".into(),\n            output_name: \"output\".into(),\n            input_names: vec![],\n            ndim: 4,\n            h: tiles_y * tile_size,\n            w: tiles_x * tile_size,\n            tile: None,\n            tiles: None,\n            segment_size: None,\n            total_elements: None,\n            original_shape: vec![],\n        }\n    }\n\n    #[test]\n    fn reshape_to_4d_valid() {\n        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();\n        let arr = reshape_to_4d(&data, 2, 3, 4).unwrap();\n        assert_eq!(arr.dim(), (1, 2, 3, 4));\n    }\n\n    #[test]\n    fn reshape_to_4d_single_element() {\n        let data = vec![42.0];\n        let arr = reshape_to_4d(&data, 1, 1, 1).unwrap();\n        assert_eq!(arr.dim(), (1, 1, 1, 1));\n        assert_eq!(arr[[0, 0, 0, 0]], 42.0);\n    }\n\n    #[test]\n    fn reshape_to_4d_mismatch() {\n        let data = vec![1.0; 10];\n        assert!(reshape_to_4d(&data, 2, 3, 4).is_err());\n    }\n\n    #[test]\n    fn reshape_to_4d_empty() {\n  
      let data: Vec<f64> = vec![];\n        assert!(reshape_to_4d(&data, 1, 1, 1).is_err());\n    }\n\n    #[test]\n    fn split_into_tiles_2x2_no_halo() {\n        let input =\n            Array4::from_shape_vec((1, 1, 4, 4), (0..16).map(|i| i as f64).collect()).unwrap();\n        let tiling = make_tiling(2, 2, 2, [0, 0, 0, 0], [2, 2], 1);\n        let tiles = split_into_tiles(&input, &tiling).unwrap();\n        assert_eq!(tiles.len(), 4);\n        for tile in &tiles {\n            assert_eq!(tile.dim(), (1, 1, 2, 2));\n        }\n    }\n\n    #[test]\n    fn split_into_tiles_with_halo() {\n        let input =\n            Array4::from_shape_vec((1, 1, 4, 4), (0..16).map(|i| i as f64).collect()).unwrap();\n        let tiling = make_tiling(2, 2, 2, [1, 1, 1, 1], [2, 2], 1);\n        let tiles = split_into_tiles(&input, &tiling).unwrap();\n        assert_eq!(tiles.len(), 4);\n        for tile in &tiles {\n            assert_eq!(tile.dim(), (1, 1, 4, 4));\n        }\n    }\n\n    #[test]\n    fn split_into_tiles_negative_halo_rejected() {\n        let input = Array4::zeros((1, 1, 4, 4));\n        let tiling = make_tiling(2, 2, 2, [-1, 0, 0, 0], [2, 2], 1);\n        assert!(split_into_tiles(&input, &tiling).is_err());\n    }\n\n    #[test]\n    fn split_into_tiles_batch_gt1_rejected() {\n        let input = Array4::zeros((2, 1, 4, 4));\n        let tiling = make_tiling(2, 1, 1, [0, 0, 0, 0], [2, 2], 1);\n        assert!(split_into_tiles(&input, &tiling).is_err());\n    }\n\n    #[test]\n    fn reconstruct_from_tiles_2x2() {\n        let c_out = 1;\n        let out_h = 2usize;\n        let out_w = 2usize;\n        let tiling = make_tiling(4, 2, 2, [0, 0, 0, 0], [out_h as i64, out_w as i64], c_out);\n\n        let tiles: Vec<ArrayD<f64>> = (0..4)\n            .map(|i| {\n                ArrayD::from_shape_vec(\n                    IxDyn(&[1, c_out, out_h, out_w]),\n                    vec![i as f64; c_out * out_h * out_w],\n                )\n                .unwrap()\n 
           })\n            .collect();\n\n        let output = reconstruct_from_tiles(&tiles, &tiling).unwrap();\n        assert_eq!(output.shape(), &[1, c_out, 4, 4]);\n    }\n\n    #[test]\n    fn reconstruct_from_tiles_empty() {\n        let tiling = make_tiling(2, 1, 1, [0, 0, 0, 0], [2, 2], 1);\n        assert!(reconstruct_from_tiles(&[], &tiling).is_err());\n    }\n\n    #[test]\n    fn reconstruct_from_tiles_wrong_element_count() {\n        let tiling = make_tiling(2, 1, 1, [0, 0, 0, 0], [2, 2], 1);\n        let bad_tile = vec![ArrayD::from_shape_vec(IxDyn(&[3]), vec![1.0; 3]).unwrap()];\n        assert!(reconstruct_from_tiles(&bad_tile, &tiling).is_err());\n    }\n\n    #[test]\n    fn reconstruct_from_tiles_wrong_tile_count() {\n        let c_out = 1;\n        let out_h = 2i64;\n        let out_w = 2i64;\n        let tiling = make_tiling(4, 2, 2, [0, 0, 0, 0], [out_h, out_w], c_out);\n        let make_tile = || {\n            ArrayD::from_shape_vec(\n                IxDyn(&[1, c_out, out_h as usize, out_w as usize]),\n                vec![0.0f64; c_out * out_h as usize * out_w as usize],\n            )\n            .unwrap()\n        };\n        let too_few: Vec<ArrayD<f64>> = (0..3).map(|_| make_tile()).collect();\n        assert!(reconstruct_from_tiles(&too_few, &tiling).is_err());\n        let too_many: Vec<ArrayD<f64>> = (0..5).map(|_| make_tile()).collect();\n        assert!(reconstruct_from_tiles(&too_many, &tiling).is_err());\n    }\n\n    #[test]\n    fn split_reconstruct_roundtrip() {\n        let c = 2;\n        let h = 8;\n        let w = 8;\n        let data: Vec<f64> = (0..(c * h * w)).map(|i| i as f64).collect();\n        let input = Array4::from_shape_vec((1, c, h, w), data).unwrap();\n\n        let tile_size = 4;\n        let tiling = make_tiling(tile_size, 2, 2, [0, 0, 0, 0], [4, 4], c);\n\n        let tiles = split_into_tiles(&input, &tiling).unwrap();\n        assert_eq!(tiles.len(), 4);\n\n        let tile_outputs: Vec<ArrayD<f64>> = 
tiles.into_iter().map(|t| t.into_dyn()).collect();\n        let reconstructed = reconstruct_from_tiles(&tile_outputs, &tiling).unwrap();\n        assert_eq!(reconstructed.shape(), &[1, c, h, w]);\n\n        let input_dyn = input.into_dyn();\n        assert_eq!(input_dyn, reconstructed);\n    }\n\n    #[test]\n    fn store_named_outputs_basic() {\n        let mut cache = TensorStore::new();\n        let names = vec![\"out_a\".to_string(), \"out_b\".to_string()];\n        let mut named = HashMap::new();\n        named.insert(\"out_a\".to_string(), (vec![1.0, 2.0], vec![2]));\n        named.insert(\"out_b\".to_string(), (vec![3.0], vec![1]));\n\n        store_named_outputs(&mut cache, &names, named).unwrap();\n        assert_eq!(cache.get(\"out_a\").unwrap().shape(), &[2]);\n        assert_eq!(cache.get(\"out_b\").unwrap().shape(), &[1]);\n    }\n\n    #[test]\n    fn store_named_outputs_missing_name_errors() {\n        let mut cache = TensorStore::new();\n        let names = vec![\"missing\".to_string()];\n        let named = HashMap::new();\n        let result = store_named_outputs(&mut cache, &names, named);\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn store_named_outputs_partial_write_errors() {\n        let mut cache = TensorStore::new();\n        cache.put(\n            \"pre_existing\".into(),\n            ArrayD::from_shape_vec(ndarray::IxDyn(&[1]), vec![99.0]).unwrap(),\n        );\n        let names = vec![\"present\".to_string(), \"missing\".to_string()];\n        let mut named = HashMap::new();\n        named.insert(\"present\".to_string(), (vec![1.0, 2.0], vec![2]));\n        let result = store_named_outputs(&mut cache, &names, named);\n        assert!(result.is_err());\n        assert!(cache.contains(\"pre_existing\"));\n        assert!(!cache.contains(\"present\"));\n    }\n\n    #[test]\n    fn run_config_default() {\n        let config = RunConfig::default();\n        assert_eq!(config.parallel, 1);\n        
assert!(!config.batch);\n        assert!(config.weights_onnx.is_none());\n        assert!(config.combined);\n    }\n\n    #[test]\n    fn multi_input_activation_concatenation_ordering() {\n        use ndarray::IxDyn;\n        let mut cache = TensorStore::new();\n        cache.put(\n            \"act_a\".into(),\n            ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),\n        );\n        cache.put(\n            \"act_b\".into(),\n            ArrayD::from_shape_vec(IxDyn(&[2]), vec![7.0, 8.0]).unwrap(),\n        );\n        cache.put(\n            \"act_c\".into(),\n            ArrayD::from_shape_vec(IxDyn(&[1]), vec![9.0]).unwrap(),\n        );\n\n        let inputs = vec![\n            \"act_a\".to_string(),\n            \"act_b\".to_string(),\n            \"act_c\".to_string(),\n        ];\n        let mut flat: Vec<f64> = Vec::new();\n        for name in &inputs {\n            let arr = cache.get(name).unwrap();\n            flat.extend(arr.iter());\n        }\n\n        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);\n    }\n\n    #[test]\n    fn multi_input_activation_missing_tensor_error() {\n        let mut cache = TensorStore::new();\n        cache.put(\n            \"act_a\".into(),\n            ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[2]), vec![1.0, 2.0]).unwrap(),\n        );\n\n        let inputs = vec![\"act_a\".to_string(), \"act_missing\".to_string()];\n        let mut flat: Vec<f64> = Vec::new();\n        let mut err = None;\n        for name in &inputs {\n            match cache.get(name) {\n                Ok(arr) => flat.extend(arr.iter()),\n                Err(e) => {\n                    err = Some(e);\n                    break;\n                }\n            }\n        }\n\n        assert!(err.is_some());\n        assert!(err.unwrap().to_string().contains(\"act_missing\"));\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/slice_cache.rs",
    "content": "use std::io::Read;\nuse std::path::Path;\n\nuse crate::error::{DsperseError, Result};\n\npub struct SliceAssets {\n    pub circuit_bytes: Option<Vec<u8>>,\n    pub onnx_bytes: Option<Vec<u8>>,\n}\n\nimpl SliceAssets {\n    pub fn load_from_dslice(slices_dir: &Path, slice_id: &str) -> Result<Self> {\n        let archive_path = slices_dir.join(format!(\"{slice_id}.dslice\"));\n        if !archive_path.exists() {\n            return Ok(Self {\n                circuit_bytes: None,\n                onnx_bytes: None,\n            });\n        }\n\n        let file =\n            std::fs::File::open(&archive_path).map_err(|e| DsperseError::io(e, &archive_path))?;\n        let mut zip = zip::ZipArchive::new(file).map_err(|e| {\n            DsperseError::Slicer(format!(\n                \"reading dslice archive {}: {e}\",\n                archive_path.display()\n            ))\n        })?;\n\n        let mut circuit_bytes = None;\n        let mut onnx_bytes = None;\n\n        for i in 0..zip.len() {\n            let mut entry = zip\n                .by_index(i)\n                .map_err(|e| DsperseError::Slicer(format!(\"reading zip entry {i}: {e}\")))?;\n            let name = entry.name().to_string();\n\n            if name.ends_with(\"circuit.bin\") {\n                let mut buf = Vec::with_capacity(entry.size() as usize);\n                entry.read_to_end(&mut buf).map_err(|e| {\n                    DsperseError::Slicer(format!(\"reading circuit.bin from dslice: {e}\"))\n                })?;\n                circuit_bytes = Some(buf);\n            } else if name.ends_with(\".onnx\") && name.starts_with(\"payload/\") {\n                let mut buf = Vec::with_capacity(entry.size() as usize);\n                entry\n                    .read_to_end(&mut buf)\n                    .map_err(|e| DsperseError::Slicer(format!(\"reading onnx from dslice: {e}\")))?;\n                onnx_bytes = Some(buf);\n            }\n        }\n\n        Ok(Self {\n        
    circuit_bytes,\n            onnx_bytes,\n        })\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/stage.rs",
    "content": "use std::path::Path;\n\nuse rayon::prelude::*;\n\nuse crate::backend::ProofBackend;\nuse crate::error::{DsperseError, Result};\nuse crate::schema::execution::{ExecutionMethod, RunMetadata, SliceResult, TileResult};\nuse crate::schema::metadata::RunSliceMetadata;\nuse crate::schema::tiling::TilingInfo;\nuse crate::utils::paths::resolve_relative_path;\n\nuse super::tile_executor::resolve_tile_circuit;\n\n#[derive(Debug, Clone, Copy)]\npub enum PipelineStage {\n    Prove,\n    Verify,\n}\n\nimpl PipelineStage {\n    fn execution_method(&self) -> ExecutionMethod {\n        match self {\n            Self::Prove => ExecutionMethod::JstproveProve,\n            Self::Verify => ExecutionMethod::JstproveVerify,\n        }\n    }\n\n    fn action_label(&self) -> &'static str {\n        match self {\n            Self::Prove => \"proving\",\n            Self::Verify => \"verifying\",\n        }\n    }\n\n    fn past_label(&self) -> &'static str {\n        match self {\n            Self::Prove => \"proved\",\n            Self::Verify => \"verified\",\n        }\n    }\n\n    fn error_label(&self) -> &'static str {\n        match self {\n            Self::Prove => \"proof\",\n            Self::Verify => \"verification\",\n        }\n    }\n}\n\npub fn run_pipeline_stage(\n    stage: PipelineStage,\n    run_dir: &Path,\n    slices_dir: &Path,\n    backend: &dyn ProofBackend,\n    parallel: usize,\n) -> Result<RunMetadata> {\n    let meta_path = run_dir.join(crate::utils::paths::METADATA_FILE);\n    let data = crate::utils::limits::read_checked(&meta_path)?;\n    let mut run_meta: RunMetadata = rmp_serde::from_slice(&data)?;\n\n    let circuit_slices: Vec<(String, _)> = run_meta\n        .iter_circuit_slices()\n        .map(|(id, meta)| (id.to_string(), meta.clone()))\n        .collect();\n\n    tracing::info!(\n        total = circuit_slices.len(),\n        \"{} circuit slices\",\n        stage.action_label()\n    );\n\n    let pool = 
rayon::ThreadPoolBuilder::new()\n        .num_threads(parallel)\n        .build()\n        .map_err(|e| DsperseError::Pipeline(format!(\"thread pool: {e}\")))?;\n\n    let results: Vec<_> = pool.install(|| {\n        circuit_slices\n            .par_iter()\n            .map(|(slice_id, meta)| {\n                if slice_id.strip_prefix(\"slice_\").and_then(|s| s.parse::<usize>().ok()).is_none() {\n                    return (\n                        slice_id.clone(),\n                        Err(DsperseError::Pipeline(format!(\n                            \"invalid slice_id format: {slice_id:?}\"\n                        ))),\n                    );\n                }\n                let slice_run_dir = run_dir.join(slice_id);\n\n                let result =\n                    execute_single_slice(stage, slices_dir, &slice_run_dir, slice_id, meta, backend);\n\n                match &result {\n                    Ok(r) if r.success => tracing::info!(slice = %slice_id, \"{}\", stage.past_label()),\n                    Ok(r) => tracing::error!(\n                        slice = %slice_id,\n                        error = r.error.as_deref().unwrap_or(\"unknown\"),\n                        \"{} failed\", stage.error_label()\n                    ),\n                    Err(e) => tracing::error!(slice = %slice_id, error = %e, \"{} error\", stage.error_label()),\n                }\n\n                (slice_id.clone(), result)\n            })\n            .collect()\n    });\n\n    let method = stage.execution_method();\n    let mut succeeded = 0;\n    for (slice_id, result) in results {\n        let slice_result = match result {\n            Ok(r) => {\n                if r.success {\n                    succeeded += 1;\n                }\n                r\n            }\n            Err(e) => SliceResult::failure(slice_id.clone(), method, e.to_string(), 0.0),\n        };\n\n        if let Some(entry) = run_meta\n            .execution_chain\n            
.execution_results\n            .iter_mut()\n            .find(|e| e.slice_id == slice_id)\n        {\n            match stage {\n                PipelineStage::Prove => entry.proof_execution = Some(slice_result),\n                PipelineStage::Verify => entry.verification_execution = Some(slice_result),\n            }\n        } else {\n            tracing::warn!(\n                slice = %slice_id,\n                stage = ?stage,\n                success = slice_result.success,\n                error = slice_result.error.as_deref().unwrap_or(\"none\"),\n                \"no matching execution_results entry, result dropped\"\n            );\n        }\n    }\n\n    match stage {\n        PipelineStage::Prove => run_meta.execution_chain.jstprove_proved_slices = succeeded,\n        PipelineStage::Verify => run_meta.execution_chain.jstprove_verified_slices = succeeded,\n    }\n\n    let meta_bytes = rmp_serde::to_vec_named(&run_meta)?;\n    std::fs::write(&meta_path, meta_bytes).map_err(|e| DsperseError::io(e, &meta_path))?;\n\n    tracing::info!(\n        succeeded,\n        total = circuit_slices.len(),\n        \"{} complete\",\n        stage.action_label()\n    );\n    Ok(run_meta)\n}\n\nfn execute_single_slice(\n    stage: PipelineStage,\n    slices_dir: &Path,\n    slice_run_dir: &Path,\n    slice_id: &str,\n    meta: &RunSliceMetadata,\n    backend: &dyn ProofBackend,\n) -> Result<SliceResult> {\n    if let Some(ref tiling) = meta.tiling {\n        let default_circuit_path = meta\n            .jstprove_circuit_path\n            .as_deref()\n            .map(|p| resolve_relative_path(slices_dir, p))\n            .transpose()?;\n        return execute_tiled_stage(\n            stage,\n            slice_id,\n            default_circuit_path.as_deref(),\n            slice_run_dir,\n            tiling,\n            slices_dir,\n            backend,\n        );\n    }\n\n    let circuit_path = meta\n        .jstprove_circuit_path\n        .as_deref()\n        
.map(|p| resolve_relative_path(slices_dir, p))\n        .transpose()?\n        .ok_or_else(|| DsperseError::Pipeline(format!(\"no circuit path for {slice_id}\")))?;\n\n    let start = std::time::Instant::now();\n    let method = stage.execution_method();\n    let witness_path = slice_run_dir.join(crate::utils::paths::WITNESS_FILE);\n    let witness_bytes = match crate::utils::limits::read_checked(&witness_path) {\n        Ok(b) => b,\n        Err(e) => {\n            return Ok(SliceResult::failure(\n                slice_id,\n                method,\n                format!(\"witness file read error: {}: {e}\", witness_path.display()),\n                start.elapsed().as_secs_f64(),\n            ));\n        }\n    };\n\n    execute_stage_operation(\n        stage,\n        slice_id,\n        &circuit_path,\n        &witness_bytes,\n        slice_run_dir,\n        backend,\n        start,\n        method,\n    )\n}\n\n#[allow(clippy::too_many_arguments)]\nfn execute_stage_operation(\n    stage: PipelineStage,\n    slice_id: &str,\n    circuit_path: &Path,\n    witness_bytes: &[u8],\n    output_dir: &Path,\n    backend: &dyn ProofBackend,\n    start: std::time::Instant,\n    method: ExecutionMethod,\n) -> Result<SliceResult> {\n    match stage {\n        PipelineStage::Prove => {\n            let proof_bytes = backend.prove(circuit_path, witness_bytes)?;\n            let proof_path = output_dir.join(crate::utils::paths::PROOF_FILE);\n            std::fs::write(&proof_path, &proof_bytes)\n                .map_err(|e| DsperseError::io(e, &proof_path))?;\n\n            let mut result = SliceResult::success(slice_id, method, start.elapsed().as_secs_f64());\n            result.proof_path = Some(proof_path.to_string_lossy().into_owned());\n            Ok(result)\n        }\n        PipelineStage::Verify => {\n            let proof_path = output_dir.join(crate::utils::paths::PROOF_FILE);\n            let proof_bytes = match crate::utils::limits::read_checked(&proof_path) 
{\n                Ok(b) => b,\n                Err(e) => {\n                    return Ok(SliceResult::failure(\n                        slice_id,\n                        method,\n                        format!(\"proof file read error: {}: {e}\", proof_path.display()),\n                        start.elapsed().as_secs_f64(),\n                    ));\n                }\n            };\n\n            let valid = backend.verify(circuit_path, witness_bytes, &proof_bytes)?;\n\n            let elapsed = start.elapsed().as_secs_f64();\n            let mut result = if valid {\n                SliceResult::success(slice_id, method, elapsed)\n            } else {\n                SliceResult::failure(\n                    slice_id,\n                    method,\n                    \"proof verification failed\".into(),\n                    elapsed,\n                )\n            };\n            result.proof_path = Some(proof_path.to_string_lossy().into_owned());\n            Ok(result)\n        }\n    }\n}\n\nfn execute_tiled_stage(\n    stage: PipelineStage,\n    slice_id: &str,\n    default_circuit_path: Option<&Path>,\n    slice_run_dir: &Path,\n    tiling: &TilingInfo,\n    slices_dir: &Path,\n    backend: &dyn ProofBackend,\n) -> Result<SliceResult> {\n    if tiling.num_tiles == 0 {\n        return Err(DsperseError::Pipeline(format!(\n            \"{slice_id}: tiling.num_tiles is 0\"\n        )));\n    }\n\n    let start = std::time::Instant::now();\n    let method = stage.execution_method();\n\n    let tile_results: Vec<TileResult> = (0..tiling.num_tiles)\n        .into_par_iter()\n        .map(|tile_idx| {\n            let tile_start = std::time::Instant::now();\n            let fail = |error: String| {\n                TileResult::failure(\n                    tile_idx,\n                    error,\n                    Some(method),\n                    tile_start.elapsed().as_secs_f64(),\n                )\n            };\n            let tile_dir = 
slice_run_dir.join(format!(\"tile_{tile_idx}\"));\n\n            let tile_circuit_path =\n                match resolve_tile_circuit(tiling, tile_idx, slices_dir, default_circuit_path) {\n                    Ok(Some(p)) => p,\n                    Ok(None) => return fail(format!(\"no circuit path for tile {tile_idx}\")),\n                    Err(e) => return fail(e),\n                };\n\n            let witness_path = tile_dir.join(crate::utils::paths::WITNESS_FILE);\n            let witness_bytes = match crate::utils::limits::read_checked(&witness_path) {\n                Ok(b) => b,\n                Err(e) => {\n                    return fail(format!(\n                        \"witness read error: {}: {e}\",\n                        witness_path.display()\n                    ));\n                }\n            };\n\n            execute_tile_stage_operation(\n                stage,\n                tile_idx,\n                &tile_circuit_path,\n                &witness_bytes,\n                &tile_dir,\n                backend,\n                method,\n                tile_start,\n            )\n        })\n        .collect();\n\n    let failed = tile_results.iter().filter(|t| !t.success).count();\n    let all_success = failed == 0;\n\n    let elapsed = start.elapsed().as_secs_f64();\n    let mut result = if all_success {\n        SliceResult::success(slice_id, method, elapsed)\n    } else {\n        SliceResult::failure(\n            slice_id,\n            method,\n            format!(\"{failed} of {} tiles failed\", tiling.num_tiles),\n            elapsed,\n        )\n    };\n    result.tiles = tile_results;\n    Ok(result)\n}\n\n#[allow(clippy::too_many_arguments)]\nfn execute_tile_stage_operation(\n    stage: PipelineStage,\n    tile_idx: usize,\n    circuit_path: &Path,\n    witness_bytes: &[u8],\n    tile_dir: &Path,\n    backend: &dyn ProofBackend,\n    method: ExecutionMethod,\n    tile_start: std::time::Instant,\n) -> TileResult {\n    let fail = 
|error: String| {\n        TileResult::failure(\n            tile_idx,\n            error,\n            Some(method),\n            tile_start.elapsed().as_secs_f64(),\n        )\n    };\n\n    match stage {\n        PipelineStage::Prove => {\n            let proof_bytes = match backend.prove(circuit_path, witness_bytes) {\n                Ok(b) => b,\n                Err(e) => return fail(e.to_string()),\n            };\n            let proof_path = tile_dir.join(crate::utils::paths::PROOF_FILE);\n            if let Err(e) = std::fs::write(&proof_path, &proof_bytes) {\n                return fail(format!(\"write proof: {}: {e}\", proof_path.display()));\n            }\n            let mut result =\n                TileResult::success(tile_idx, Some(method), tile_start.elapsed().as_secs_f64());\n            result.proof_path = Some(proof_path.to_string_lossy().into_owned());\n            result\n        }\n        PipelineStage::Verify => {\n            let proof_path = tile_dir.join(crate::utils::paths::PROOF_FILE);\n            let proof_bytes = match crate::utils::limits::read_checked(&proof_path) {\n                Ok(b) => b,\n                Err(e) => {\n                    return fail(format!(\"proof read error: {}: {e}\", proof_path.display()));\n                }\n            };\n            let valid = match backend.verify(circuit_path, witness_bytes, &proof_bytes) {\n                Ok(v) => v,\n                Err(e) => return fail(e.to_string()),\n            };\n            let elapsed = tile_start.elapsed().as_secs_f64();\n            let mut result = if valid {\n                TileResult::success(tile_idx, Some(method), elapsed)\n            } else {\n                TileResult::failure(\n                    tile_idx,\n                    \"proof verification failed\".into(),\n                    Some(method),\n                    elapsed,\n                )\n            };\n            result.proof_path = 
Some(proof_path.to_string_lossy().into_owned());\n            result\n        }\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/strategy.rs",
    "content": "use crate::error::{DsperseError, Result};\nuse crate::schema::execution::ExecutionMethod;\nuse crate::schema::metadata::RunSliceMetadata;\nuse crate::schema::tiling::{ChannelSplitInfo, DimSplitInfo, SplitStrategy, TilingInfo};\n\npub enum ExecutionStrategy<'a> {\n    ChannelSplit(&'a ChannelSplitInfo),\n    DimSplit(&'a DimSplitInfo),\n    Tiled(&'a TilingInfo),\n    Single { use_circuit: bool },\n}\n\nimpl<'a> ExecutionStrategy<'a> {\n    pub fn from_metadata(meta: &'a RunSliceMetadata, use_circuit: bool) -> Result<Self> {\n        let has_cs = meta.channel_split.is_some();\n        let has_ds = meta.dim_split.is_some();\n        let has_tiling = meta.tiling.is_some();\n        let count = has_cs as u8 + has_ds as u8 + has_tiling as u8;\n        if count > 1 {\n            return Err(DsperseError::Metadata(format!(\n                \"slice has multiple split metadata (channel_split={has_cs}, \\\n                 dim_split={has_ds}, tiling={has_tiling}; path={:?})\",\n                meta.path\n            )));\n        }\n        match meta.split_strategy() {\n            Some(SplitStrategy::ChannelSplit(cs)) => Ok(Self::ChannelSplit(cs)),\n            Some(SplitStrategy::DimSplit(ds)) => {\n                if ds.template_path.is_none() {\n                    // Template creation may have been rejected (axis-\n                    // separability, unsupported split kind) or the template\n                    // was not included in the bundle. 
Fall back to the\n                    // non-template Single execution path (which may still\n                    // use circuit-based witness generation if use_circuit is\n                    // set) so already-published bundles with template-less\n                    // dim_split metadata remain runnable.\n                    tracing::debug!(\n                        path = ?meta.path,\n                        split_kind = ?ds.split_kind,\n                        \"dim_split template_path missing, falling back to single execution\"\n                    );\n                    Ok(Self::Single { use_circuit })\n                } else {\n                    Ok(Self::DimSplit(ds))\n                }\n            }\n            Some(SplitStrategy::Tiled(t)) => Ok(Self::Tiled(t)),\n            None => Ok(Self::Single { use_circuit }),\n        }\n    }\n\n    pub fn execution_method(&self) -> ExecutionMethod {\n        match self {\n            Self::ChannelSplit(_) => ExecutionMethod::ChannelSplit,\n            Self::DimSplit(_) => ExecutionMethod::DimSplit,\n            Self::Tiled(_) => ExecutionMethod::Tiled,\n            Self::Single { use_circuit: true } => ExecutionMethod::JstproveGenWitness,\n            Self::Single { use_circuit: false } => ExecutionMethod::OnnxOnly,\n        }\n    }\n\n    pub fn output_name(&self) -> Option<&str> {\n        match self {\n            Self::ChannelSplit(cs) => Some(&cs.output_name),\n            Self::DimSplit(ds) => Some(&ds.output_name),\n            Self::Tiled(tiling) => Some(&tiling.output_name),\n            Self::Single { .. } => None,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/tensor_store.rs",
    "content": "use std::collections::HashMap;\n\nuse ndarray::ArrayD;\n\nuse crate::error::{DsperseError, Result};\n\n#[derive(Default)]\npub struct TensorStore {\n    tensors: HashMap<String, ArrayD<f64>>,\n}\n\nimpl TensorStore {\n    pub fn new() -> Self {\n        Self::default()\n    }\n\n    pub fn get(&self, name: &str) -> Result<&ArrayD<f64>> {\n        self.tensors\n            .get(name)\n            .ok_or_else(|| DsperseError::Pipeline(format!(\"tensor '{name}' not found in store\")))\n    }\n\n    pub fn try_get(&self, name: &str) -> Option<&ArrayD<f64>> {\n        self.tensors.get(name)\n    }\n\n    pub fn put(&mut self, name: String, tensor: ArrayD<f64>) {\n        self.tensors.insert(name, tensor);\n    }\n\n    pub fn contains(&self, name: &str) -> bool {\n        self.tensors.contains_key(name)\n    }\n\n    pub fn len(&self) -> usize {\n        self.tensors.len()\n    }\n\n    pub fn is_empty(&self) -> bool {\n        self.tensors.is_empty()\n    }\n\n    pub fn keys(&self) -> impl Iterator<Item = &String> {\n        self.tensors.keys()\n    }\n\n    pub fn as_map(&self) -> &HashMap<String, ArrayD<f64>> {\n        &self.tensors\n    }\n\n    pub fn gather(&self, names: &[String]) -> Result<ArrayD<f64>> {\n        crate::utils::io::gather_inputs_from_cache(&self.tensors, names)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use ndarray::IxDyn;\n\n    #[test]\n    fn put_and_get() {\n        let mut store = TensorStore::new();\n        let arr = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0]).unwrap();\n        store.put(\"x\".into(), arr.clone());\n        assert_eq!(store.get(\"x\").unwrap(), &arr);\n    }\n\n    #[test]\n    fn get_missing_returns_error() {\n        let store = TensorStore::new();\n        assert!(store.get(\"missing\").is_err());\n    }\n\n    #[test]\n    fn try_get_missing_returns_none() {\n        let store = TensorStore::new();\n        assert!(store.try_get(\"missing\").is_none());\n    }\n\n    
#[test]\n    fn contains_check() {\n        let mut store = TensorStore::new();\n        assert!(!store.contains(\"a\"));\n        store.put(\n            \"a\".into(),\n            ArrayD::from_shape_vec(IxDyn(&[1]), vec![0.0]).unwrap(),\n        );\n        assert!(store.contains(\"a\"));\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/tile_executor.rs",
    "content": "use std::path::{Path, PathBuf};\n\nuse rayon::prelude::*;\n\nuse crate::error::{DsperseError, Result};\nuse crate::schema::tiling::TilingInfo;\nuse crate::utils::paths::resolve_relative_path;\n\npub fn resolve_tile_circuit(\n    tiling: &TilingInfo,\n    tile_idx: usize,\n    slices_dir: &Path,\n    default_circuit: Option<&Path>,\n) -> std::result::Result<Option<PathBuf>, String> {\n    let from_tiles = tiling\n        .tiles\n        .as_deref()\n        .and_then(|ts| ts.get(tile_idx))\n        .and_then(|ti| ti.jstprove_circuit_path.as_deref());\n    let from_single = tiling\n        .tile\n        .as_ref()\n        .and_then(|ti| ti.jstprove_circuit_path.as_deref());\n    let path_str = from_tiles.or(from_single);\n    match path_str {\n        Some(p) => match resolve_relative_path(slices_dir, p) {\n            Ok(resolved) => Ok(Some(resolved)),\n            Err(e) => Err(e.to_string()),\n        },\n        None => Ok(default_circuit.map(|p| p.to_path_buf())),\n    }\n}\n\npub fn execute_tiles<T, F>(parallel: usize, num_tiles: usize, op: F) -> Result<Vec<T>>\nwhere\n    T: Send,\n    F: Fn(usize) -> T + Send + Sync,\n{\n    if num_tiles == 0 {\n        return Err(DsperseError::Pipeline(\"num_tiles is 0\".into()));\n    }\n\n    let pool = rayon::ThreadPoolBuilder::new()\n        .num_threads(parallel)\n        .build()\n        .map_err(|e| DsperseError::Pipeline(format!(\"thread pool: {e}\")))?;\n\n    let results = pool.install(|| (0..num_tiles).into_par_iter().map(op).collect());\n\n    Ok(results)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::schema::tiling::{TileInfo, TilingInfo};\n\n    fn make_tiling() -> TilingInfo {\n        TilingInfo {\n            slice_idx: 0,\n            tile_size: 4,\n            num_tiles: 4,\n            tiles_y: 2,\n            tiles_x: 2,\n            halo: [0, 0, 0, 0],\n            out_tile: [4, 4],\n            stride: [1, 1],\n            c_in: 1,\n            c_out: 1,\n         
   input_name: \"input\".into(),\n            output_name: \"output\".into(),\n            input_names: vec![],\n            ndim: 4,\n            h: 8,\n            w: 8,\n            tile: None,\n            tiles: None,\n            segment_size: None,\n            total_elements: None,\n            original_shape: vec![],\n        }\n    }\n\n    #[test]\n    fn resolve_tile_circuit_no_info() {\n        let tiling = make_tiling();\n        let result = resolve_tile_circuit(&tiling, 0, Path::new(\"/tmp\"), None);\n        assert_eq!(result.unwrap(), None);\n    }\n\n    #[test]\n    fn resolve_tile_circuit_with_default() {\n        let tiling = make_tiling();\n        let default = PathBuf::from(\"/tmp/circuit.bundle\");\n        let result = resolve_tile_circuit(&tiling, 0, Path::new(\"/tmp\"), Some(&default));\n        assert_eq!(result.unwrap(), Some(default));\n    }\n\n    #[test]\n    fn resolve_tile_circuit_from_single_tile() {\n        let mut tiling = make_tiling();\n        tiling.tile = Some(TileInfo {\n            path: \"tile.onnx\".into(),\n            conv_out: [4, 4],\n            jstprove_circuit_path: Some(\"jstprove/circuit.bundle\".into()),\n        });\n        let result = resolve_tile_circuit(&tiling, 0, Path::new(\"/slices\"), None);\n        let resolved = result.unwrap().unwrap();\n        assert!(resolved.to_string_lossy().contains(\"circuit.bundle\"));\n    }\n\n    #[test]\n    fn execute_tiles_collects_results() {\n        let results = execute_tiles(2, 4, |i| i * 2).unwrap();\n        assert_eq!(results.len(), 4);\n        let mut sorted = results.clone();\n        sorted.sort();\n        assert_eq!(sorted, vec![0, 2, 4, 6]);\n    }\n\n    #[test]\n    fn execute_tiles_zero_tiles_errors() {\n        let result = execute_tiles(1, 0, |i| i);\n        assert!(result.is_err());\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/tiled.rs",
    "content": "use std::collections::HashMap;\nuse std::path::Path;\nuse std::sync::Arc;\n\nuse ndarray::{Array4, ArrayD, IxDyn, s};\nuse rayon::prelude::*;\n\nuse super::tensor_store::TensorStore;\nuse crate::backend::jstprove::JstproveBackend;\nuse crate::error::{DsperseError, Result};\nuse crate::schema::execution::{ExecutionInfo, ExecutionMethod, TileResult};\nuse crate::schema::tiling::TilingInfo;\nuse crate::slicer::onnx_proto::TensorProto;\nuse crate::utils::paths::resolve_relative_path;\n\nuse super::runner::{\n    RunConfig, extract_initializers_from_map, extract_onnx_initializers,\n    resolve_circuit_path_optional, run_onnx_inference,\n};\n\n#[allow(clippy::too_many_arguments)]\npub(crate) fn execute_tiled(\n    slices_dir: &Path,\n    slice_run_dir: &Path,\n    slice_id: &str,\n    tiling: &TilingInfo,\n    slice_circuit_path: Option<&Path>,\n    tensor_cache: &TensorStore,\n    backend: &JstproveBackend,\n    config: &RunConfig,\n    donor_init_map: Option<&HashMap<String, &TensorProto>>,\n) -> Result<crate::schema::execution::StrategyOutput> {\n    let all_names = tiling.all_input_names();\n    let multi_input = all_names.len() > 1;\n    let is_fixed_segment = tiling.ndim == 1;\n    let is_1d = tiling.ndim == 3;\n\n    let all_tiles_dyn = if is_fixed_segment {\n        prepare_fixed_segments_from_cache(tiling, tensor_cache)?\n    } else {\n        prepare_tiles_from_cache(tiling, tensor_cache, is_1d)?\n    };\n\n    let num_tiles = all_tiles_dyn[0].len();\n\n    tracing::info!(\n        slice = %slice_id,\n        num_tiles,\n        tile_size = tiling.tile_size,\n        ndim = tiling.ndim,\n        \"splitting into tiles\"\n    );\n\n    let tile_infos = tiling.tiles.as_deref().unwrap_or(&[]);\n    let single_tile = tiling.tile.as_ref();\n\n    if tile_infos.is_empty() && single_tile.is_none() {\n        return Err(DsperseError::Pipeline(format!(\n            \"tiling for '{}' has neither tile list nor single tile template\",\n            
tiling.output_name\n        )));\n    }\n\n    let first_tile_info = tile_infos.first().or(single_tile);\n    let first_tile_onnx = first_tile_info\n        .map(|ti| resolve_relative_path(slices_dir, &ti.path))\n        .transpose()?;\n\n    let warm_model = if multi_input || is_1d || is_fixed_segment {\n        None\n    } else {\n        match (first_tile_onnx.as_deref(), all_tiles_dyn[0].first()) {\n            (Some(onnx_path), Some(sample)) => {\n                let shape = sample.shape().to_vec();\n                let model = crate::backend::onnx::WarmModel::load(onnx_path, &shape)?;\n                tracing::info!(slice = %slice_id, \"loaded ONNX model\");\n                Some(model)\n            }\n            _ => None,\n        }\n    };\n\n    let circuit_path = resolve_circuit_path_optional(\n        slices_dir,\n        first_tile_info.and_then(|ti| ti.jstprove_circuit_path.as_deref()),\n    )?\n    .or_else(|| slice_circuit_path.map(|p| p.to_path_buf()));\n\n    let warm_circuit = match (&circuit_path, &first_tile_onnx) {\n        (Some(cp), Some(onnx_path)) => {\n            let params = backend.load_params(cp)?;\n            let is_wai = params.as_ref().is_some_and(|p| p.weights_as_inputs);\n\n            if donor_init_map.is_some() && !is_wai {\n                return Err(DsperseError::Pipeline(format!(\n                    \"{slice_id}: consumer weights require circuits compiled with --weights-as-inputs\"\n                )));\n            }\n\n            let initializers = if is_wai {\n                if let Some(map) = donor_init_map {\n                    extract_initializers_from_map(map, params.as_ref().unwrap())?\n                } else {\n                    extract_onnx_initializers(onnx_path, params.as_ref().unwrap())?\n                }\n            } else {\n                vec![]\n            };\n            let wc = crate::backend::jstprove::WarmCircuit::load(cp, initializers, backend)?;\n            tracing::info!(slice = 
%slice_id, wai = is_wai, \"loaded circuit bundle\");\n            Some(wc)\n        }\n        _ => None,\n    };\n\n    let warm_model = warm_model.map(Arc::new);\n    let warm_circuit = warm_circuit.map(Arc::new);\n    let circuit_path = circuit_path.map(Arc::from);\n\n    let pool = rayon::ThreadPoolBuilder::new()\n        .num_threads(config.parallel)\n        .build()\n        .map_err(|e| DsperseError::Pipeline(format!(\"thread pool: {e}\")))?;\n\n    let tile_input_names: Vec<String> = if all_names.len() > 1 {\n        (0..all_names.len())\n            .map(|i| format!(\"tile_in_{i}\"))\n            .collect()\n    } else {\n        vec![\"tile_in\".to_string()]\n    };\n\n    let collected: Vec<(TileResult, Option<ArrayD<f64>>)> = pool.install(|| {\n        (0..num_tiles)\n            .into_par_iter()\n            .map(|tile_idx| {\n                let start = std::time::Instant::now();\n                let tile_dir = slice_run_dir.join(format!(\"tile_{tile_idx}\"));\n                if let Err(e) = std::fs::create_dir_all(&tile_dir) {\n                    return (\n                        TileResult::failure(\n                            tile_idx,\n                            format!(\"mkdir: {e}\"),\n                            None,\n                            start.elapsed().as_secs_f64(),\n                        ),\n                        None,\n                    );\n                }\n\n                let tile_info = tile_infos.get(tile_idx).or(single_tile);\n                let tile_dyn = all_tiles_dyn[0][tile_idx].clone();\n\n                let per_tile_onnx = tile_info\n                    .map(|ti| resolve_relative_path(slices_dir, &ti.path))\n                    .transpose();\n                let per_tile_onnx = match per_tile_onnx {\n                    Ok(p) => p,\n                    Err(e) => {\n                        return (\n                            TileResult::failure(\n                                tile_idx,\n                
                format!(\"resolve tile path: {e}\"),\n                                None,\n                                start.elapsed().as_secs_f64(),\n                            ),\n                            None,\n                        );\n                    }\n                };\n                let effective_tile_onnx_ref = per_tile_onnx.as_deref();\n\n                if tile_info.is_none() {\n                    return (\n                        TileResult::failure(\n                            tile_idx,\n                            \"no tile circuit info\".into(),\n                            None,\n                            start.elapsed().as_secs_f64(),\n                        ),\n                        None,\n                    );\n                }\n\n                let tile_output = if multi_input || is_1d || is_fixed_segment {\n                    if let Some(onnx) = effective_tile_onnx_ref {\n                        let inputs: Vec<(&str, Vec<f64>, Vec<usize>)> = all_tiles_dyn\n                            .iter()\n                            .zip(tile_input_names.iter())\n                            .map(|(input_tiles, tile_name)| {\n                                let t = &input_tiles[tile_idx];\n                                let shape: Vec<usize> = t.shape().to_vec();\n                                let data: Vec<f64> = t.iter().copied().collect();\n                                (tile_name.as_str(), data, shape)\n                            })\n                            .collect();\n                        crate::backend::onnx::run_inference_multi_named(onnx, &inputs).and_then(\n                            |named| {\n                                let (data, shape) =\n                                    named.into_values().next().ok_or_else(|| {\n                                        DsperseError::Pipeline(\n                                            \"multi-input tile produced no output\".into(),\n                         
               )\n                                    })?;\n                                ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|e| {\n                                    DsperseError::Pipeline(format!(\n                                        \"multi-input tile output reshape: {e}\"\n                                    ))\n                                })\n                            },\n                        )\n                    } else {\n                        Err(DsperseError::Pipeline(format!(\n                            \"tile {tile_idx}: no ONNX model available for inference\"\n                        )))\n                    }\n                } else if let Some(ref wm) = warm_model {\n                    let input_flat: Vec<f64> = tile_dyn.iter().copied().collect();\n                    wm.run(&input_flat).and_then(|(data, shape)| {\n                        ArrayD::from_shape_vec(IxDyn(&shape), data).map_err(|e| {\n                            crate::error::DsperseError::Pipeline(format!(\n                                \"warm model output reshape: {e}\"\n                            ))\n                        })\n                    })\n                } else if let Some(onnx) = effective_tile_onnx_ref {\n                    run_onnx_inference(onnx, &tile_dyn)\n                } else {\n                    Err(DsperseError::Pipeline(format!(\n                        \"tile {tile_idx}: no ONNX model available for inference\"\n                    )))\n                };\n\n                let output_tensor = match tile_output {\n                    Ok(t) => t,\n                    Err(e) => {\n                        return (\n                            TileResult::failure(\n                                tile_idx,\n                                format!(\"onnx inference: {e}\"),\n                                Some(ExecutionMethod::OnnxOnly),\n                                start.elapsed().as_secs_f64(),\n                       
     ),\n                            None,\n                        );\n                    }\n                };\n\n                if circuit_path.is_none() {\n                    return (\n                        TileResult::success(\n                            tile_idx,\n                            Some(ExecutionMethod::OnnxOnly),\n                            start.elapsed().as_secs_f64(),\n                        ),\n                        Some(output_tensor),\n                    );\n                }\n\n                let flat: Vec<f64> = flatten_tile_inputs(&all_tiles_dyn, tile_idx);\n                let witness_result = if let Some(ref wc) = warm_circuit {\n                    wc.witness_f64(&flat)\n                } else {\n                    let cp = circuit_path\n                        .as_ref()\n                        .expect(\"circuit_path is Some: guarded by early return\");\n                    backend.witness_f64(cp, &flat, &[])\n                };\n\n                match witness_result {\n                    Ok(witness_bytes) => {\n                        let witness_path = tile_dir.join(crate::utils::paths::WITNESS_FILE);\n                        if let Err(e) = std::fs::write(&witness_path, &witness_bytes) {\n                            return (\n                                TileResult::failure(\n                                    tile_idx,\n                                    format!(\"write witness: {e}\"),\n                                    Some(ExecutionMethod::JstproveGenWitness),\n                                    start.elapsed().as_secs_f64(),\n                                ),\n                                None,\n                            );\n                        }\n                        (\n                            TileResult::success(\n                                tile_idx,\n                                Some(ExecutionMethod::JstproveGenWitness),\n                                
start.elapsed().as_secs_f64(),\n                            ),\n                            Some(output_tensor),\n                        )\n                    }\n                    Err(e) => (\n                        TileResult::failure(\n                            tile_idx,\n                            e.to_string(),\n                            Some(ExecutionMethod::JstproveGenWitness),\n                            start.elapsed().as_secs_f64(),\n                        ),\n                        None,\n                    ),\n                }\n            })\n            .collect()\n    });\n\n    let mut tile_results: Vec<TileResult> = Vec::with_capacity(collected.len());\n    let mut tile_outputs: Vec<ArrayD<f64>> = Vec::with_capacity(collected.len());\n    for (result, output) in collected {\n        if let Some(o) = output {\n            tile_outputs.push(o);\n        }\n        tile_results.push(result);\n    }\n\n    if tile_results.is_empty() {\n        return Err(DsperseError::Pipeline(format!(\n            \"tiling produced zero tiles for '{}'\",\n            tiling.output_name\n        )));\n    }\n\n    let all_success = tile_results.iter().all(|r| r.success);\n\n    if !all_success {\n        let failed: Vec<_> = tile_results\n            .iter()\n            .filter(|r| !r.success)\n            .map(|r| format!(\"tile {}: {}\", r.tile_idx, r.error.as_deref().unwrap_or(\"?\")))\n            .collect();\n        return Err(DsperseError::Pipeline(format!(\n            \"tiled execution failed for '{}': {}\",\n            tiling.output_name,\n            failed.join(\"; \")\n        )));\n    }\n\n    debug_assert!(\n        !tile_outputs.is_empty(),\n        \"all tiles reported success but no outputs for '{}'\",\n        tiling.output_name\n    );\n    let reconstructed = if is_fixed_segment {\n        reconstruct_from_fixed_segments(&tile_outputs, tiling)?\n    } else if is_1d {\n        let r = reconstruct_from_tiles_1d(&tile_outputs, 
tiling)?;\n        trim_to_original_seq(r, tiling)?\n    } else {\n        let r = reconstruct_from_tiles(&tile_outputs, tiling)?;\n        trim_to_original_dims(r, tiling)?\n    };\n    Ok(crate::schema::execution::StrategyOutput {\n        info: ExecutionInfo {\n            method: ExecutionMethod::Tiled,\n            success: true,\n            error: None,\n            witness_file: None,\n            tile_exec_infos: tile_results,\n        },\n        outputs: vec![(tiling.output_name.clone(), reconstructed)],\n    })\n}\n\n/// Witness-only tiled execution for combined inference mode.\n///\n/// The full-model ONNX inference has already run and populated the tensor\n/// cache with all intermediate activations. This function splits those\n/// cached activations into tiles, generates per-tile ZK witnesses via the\n/// circuit backend, and returns tile-level execution results. It does NOT\n/// reconstruct output tensors — those already exist in the cache from the\n/// monolithic inference pass — hence the empty `outputs` vec in the\n/// returned `StrategyOutput`.\n#[allow(clippy::too_many_arguments)]\npub(crate) fn execute_combined_tiled(\n    slices_dir: &Path,\n    slice_run_dir: &Path,\n    slice_id: &str,\n    tiling: &TilingInfo,\n    slice_circuit_path: Option<&str>,\n    tensor_cache: &TensorStore,\n    backend: &JstproveBackend,\n    config: &RunConfig,\n    donor_init_map: Option<&HashMap<String, &TensorProto>>,\n) -> Result<crate::schema::execution::StrategyOutput> {\n    let is_fixed_segment = tiling.ndim == 1;\n    let is_1d = tiling.ndim == 3;\n    let all_tiles_dyn = if is_fixed_segment {\n        prepare_fixed_segments_from_cache(tiling, tensor_cache)?\n    } else {\n        prepare_tiles_from_cache(tiling, tensor_cache, is_1d)?\n    };\n\n    let num_tiles = all_tiles_dyn[0].len();\n\n    tracing::info!(\n        slice = %slice_id,\n        num_tiles,\n        tile_size = tiling.tile_size,\n        \"splitting combined activations into tiles for 
witness generation\"\n    );\n\n    let tile_infos = tiling.tiles.as_deref().unwrap_or(&[]);\n    let single_tile = tiling.tile.as_ref();\n    let first_tile_info = tile_infos.first().or(single_tile);\n\n    let circuit_path = resolve_circuit_path_optional(\n        slices_dir,\n        first_tile_info\n            .and_then(|ti| ti.jstprove_circuit_path.as_deref())\n            .or(slice_circuit_path),\n    )?;\n\n    let circuit_path = match circuit_path {\n        Some(p) => p,\n        None => {\n            return Ok(crate::schema::execution::StrategyOutput {\n                info: ExecutionInfo {\n                    method: ExecutionMethod::Tiled,\n                    success: true,\n                    error: None,\n                    witness_file: None,\n                    tile_exec_infos: (0..num_tiles)\n                        .map(|i| TileResult::success(i, Some(ExecutionMethod::OnnxOnly), 0.0))\n                        .collect(),\n                },\n                outputs: vec![],\n            });\n        }\n    };\n\n    let first_tile_onnx = first_tile_info\n        .map(|ti| resolve_relative_path(slices_dir, &ti.path))\n        .transpose()?;\n\n    let patched_tile_onnx = match (&first_tile_onnx, donor_init_map) {\n        (Some(onnx_path), Some(map)) => Some(crate::slicer::onnx_proto::build_patched_onnx(\n            onnx_path, map,\n        )?),\n        _ => None,\n    };\n    let effective_tile_onnx = patched_tile_onnx.as_ref().map(|t| t.path().to_path_buf());\n    let effective_tile_onnx_ref = effective_tile_onnx\n        .as_deref()\n        .or(first_tile_onnx.as_deref());\n\n    let params = backend.load_params(&circuit_path)?;\n    let is_wai = params.as_ref().is_some_and(|p| p.weights_as_inputs);\n\n    if donor_init_map.is_some() && !is_wai {\n        return Err(DsperseError::Pipeline(format!(\n            \"{slice_id}: consumer weights require circuits compiled with --weights-as-inputs\"\n        )));\n    }\n\n    let 
warm_circuit = match effective_tile_onnx_ref {\n        Some(onnx_path) => {\n            let initializers = if is_wai {\n                if let Some(map) = donor_init_map {\n                    extract_initializers_from_map(map, params.as_ref().unwrap())?\n                } else {\n                    extract_onnx_initializers(onnx_path, params.as_ref().unwrap())?\n                }\n            } else {\n                vec![]\n            };\n            let wc =\n                crate::backend::jstprove::WarmCircuit::load(&circuit_path, initializers, backend)?;\n            tracing::info!(slice = %slice_id, wai = is_wai, \"loaded tile circuit for combined tiling\");\n            Some(wc)\n        }\n        None => None,\n    };\n\n    let warm_circuit = warm_circuit.map(Arc::new);\n    let circuit_path = Arc::from(circuit_path);\n\n    let pool = rayon::ThreadPoolBuilder::new()\n        .num_threads(config.parallel)\n        .build()\n        .map_err(|e| DsperseError::Pipeline(format!(\"thread pool: {e}\")))?;\n\n    let collected: Vec<TileResult> = pool.install(|| {\n        (0..num_tiles)\n            .into_par_iter()\n            .map(|tile_idx| {\n                let start = std::time::Instant::now();\n                let tile_dir = slice_run_dir.join(format!(\"tile_{tile_idx}\"));\n                if let Err(e) = std::fs::create_dir_all(&tile_dir) {\n                    return TileResult::failure(\n                        tile_idx,\n                        format!(\"mkdir: {e}\"),\n                        None,\n                        start.elapsed().as_secs_f64(),\n                    );\n                }\n\n                let flat: Vec<f64> = flatten_tile_inputs(&all_tiles_dyn, tile_idx);\n\n                let witness_result = if let Some(ref wc) = warm_circuit {\n                    wc.witness_f64(&flat)\n                } else {\n                    backend.witness_f64(&circuit_path, &flat, &[])\n                };\n\n                match 
witness_result {\n                    Ok(witness_bytes) => {\n                        let witness_path = tile_dir.join(crate::utils::paths::WITNESS_FILE);\n                        if let Err(e) = std::fs::write(&witness_path, &witness_bytes) {\n                            return TileResult::failure(\n                                tile_idx,\n                                format!(\"write witness: {e}\"),\n                                Some(ExecutionMethod::JstproveGenWitness),\n                                start.elapsed().as_secs_f64(),\n                            );\n                        }\n                        TileResult::success(\n                            tile_idx,\n                            Some(ExecutionMethod::JstproveGenWitness),\n                            start.elapsed().as_secs_f64(),\n                        )\n                    }\n                    Err(e) => TileResult::failure(\n                        tile_idx,\n                        e.to_string(),\n                        Some(ExecutionMethod::JstproveGenWitness),\n                        start.elapsed().as_secs_f64(),\n                    ),\n                }\n            })\n            .collect()\n    });\n\n    let all_success = collected.iter().all(|r| r.success);\n    if !all_success {\n        let failed: Vec<_> = collected\n            .iter()\n            .filter(|r| !r.success)\n            .map(|r| format!(\"tile {}: {}\", r.tile_idx, r.error.as_deref().unwrap_or(\"?\")))\n            .collect();\n        return Err(DsperseError::Pipeline(format!(\n            \"{slice_id}: tiled witness generation failed: {}\",\n            failed.join(\"; \")\n        )));\n    }\n\n    tracing::info!(\n        slice = %slice_id,\n        num_tiles,\n        \"tiled witness generation from combined outputs complete\"\n    );\n\n    // No output tensors: combined mode already has activations in cache\n    // from the monolithic ONNX run. 
Only witness artifacts are produced here.\n    Ok(crate::schema::execution::StrategyOutput {\n        info: ExecutionInfo {\n            method: ExecutionMethod::Tiled,\n            success: true,\n            error: None,\n            witness_file: None,\n            tile_exec_infos: collected,\n        },\n        outputs: vec![],\n    })\n}\n\npub(crate) fn prepare_tiles_from_cache(\n    tiling: &TilingInfo,\n    tensor_cache: &TensorStore,\n    is_1d: bool,\n) -> Result<Vec<Vec<ArrayD<f64>>>> {\n    let all_names = tiling.all_input_names();\n    let mut all_tiles: Vec<Vec<ArrayD<f64>>> = Vec::with_capacity(all_names.len());\n    for name in &all_names {\n        let input_arr = tensor_cache.get(name)?.clone();\n        if is_1d {\n            let tiles = split_into_tiles_1d(&input_arr, tiling)?;\n            all_tiles.push(tiles);\n        } else {\n            let input_4d = if input_arr.ndim() == 4 {\n                let s = input_arr.shape();\n                Array4::from_shape_vec(\n                    (s[0], s[1], s[2], s[3]),\n                    input_arr.iter().copied().collect(),\n                )\n                .map_err(|e| DsperseError::Pipeline(format!(\"tiling input reshape: {e}\")))?\n            } else {\n                let input_flat: Vec<f64> = input_arr.iter().copied().collect();\n                let h = if tiling.h > 0 {\n                    tiling.h\n                } else {\n                    tiling.tiles_y * tiling.tile_size\n                };\n                let w = if tiling.w > 0 {\n                    tiling.w\n                } else {\n                    tiling.tiles_x * tiling.tile_size\n                };\n                reshape_to_4d(&input_flat, tiling.c_in, h, w)?\n            };\n            let tiles = split_into_tiles(&input_4d, tiling)?;\n            all_tiles.push(tiles.into_iter().map(|t| t.into_dyn()).collect());\n        }\n    }\n    Ok(all_tiles)\n}\n\npub fn split_for_tiling(input: &ArrayD<f64>, tiling: 
&TilingInfo) -> Result<Vec<ArrayD<f64>>> {\n    let is_fixed_segment = tiling.ndim == 1;\n    if is_fixed_segment {\n        let segment_size = tiling.segment_size.ok_or_else(|| {\n            DsperseError::Pipeline(\"split_for_tiling: fixed segment missing segment_size\".into())\n        })?;\n        if segment_size == 0 {\n            return Err(DsperseError::Pipeline(\n                \"split_for_tiling: segment_size must be > 0\".into(),\n            ));\n        }\n        let total_elements = tiling.total_elements.ok_or_else(|| {\n            DsperseError::Pipeline(\"split_for_tiling: fixed segment missing total_elements\".into())\n        })?;\n        let flat: Vec<f64> = input.iter().copied().collect();\n        if flat.len() < total_elements {\n            return Err(DsperseError::Pipeline(format!(\n                \"split_for_tiling: input has {} elements, expected at least {}\",\n                flat.len(),\n                total_elements\n            )));\n        }\n        let num_segments = total_elements.div_ceil(segment_size);\n        let mut segments = Vec::with_capacity(num_segments);\n        for i in 0..num_segments {\n            let start = i * segment_size;\n            if start >= flat.len() {\n                break;\n            }\n            let end = (start + segment_size).min(total_elements);\n            let mut seg_data = vec![0.0f64; segment_size];\n            seg_data[..end - start].copy_from_slice(&flat[start..end]);\n            segments.push(\n                ArrayD::from_shape_vec(IxDyn(&[segment_size]), seg_data)\n                    .map_err(|e| DsperseError::Pipeline(format!(\"segment reshape: {e}\")))?,\n            );\n        }\n        return Ok(segments);\n    }\n    let is_1d = tiling.ndim == 3;\n    if is_1d {\n        return split_into_tiles_1d(input, tiling);\n    }\n    let input_4d = if input.ndim() == 4 {\n        let s = input.shape();\n        Array4::from_shape_vec((s[0], s[1], s[2], s[3]), 
input.iter().copied().collect())\n            .map_err(|e| DsperseError::Pipeline(format!(\"tiling input reshape: {e}\")))?\n    } else {\n        let flat: Vec<f64> = input.iter().copied().collect();\n        let h = if tiling.h > 0 {\n            tiling.h\n        } else {\n            tiling.tiles_y * tiling.tile_size\n        };\n        let w = if tiling.w > 0 {\n            tiling.w\n        } else {\n            tiling.tiles_x * tiling.tile_size\n        };\n        reshape_to_4d(&flat, tiling.c_in, h, w)?\n    };\n    let tiles = split_into_tiles(&input_4d, tiling)?;\n    Ok(tiles.into_iter().map(|t| t.into_dyn()).collect())\n}\n\npub fn split_into_tiles(input: &Array4<f64>, tiling: &TilingInfo) -> Result<Vec<Array4<f64>>> {\n    if tiling.halo.iter().any(|&v| v < 0) {\n        return Err(DsperseError::Pipeline(format!(\n            \"negative halo values not supported: halo={:?}\",\n            tiling.halo\n        )));\n    }\n    let (n, c, h, w) = input.dim();\n    if n != 1 {\n        return Err(DsperseError::Pipeline(format!(\n            \"split_into_tiles: batch size {n} not supported, expected 1\"\n        )));\n    }\n    let halo_top = tiling.halo[0] as usize;\n    let halo_left = tiling.halo[1] as usize;\n    let halo_bottom = tiling.halo[2] as usize;\n    let halo_right = tiling.halo[3] as usize;\n    let tile_h = tiling.tile_size + halo_top + halo_bottom;\n    let tile_w = tiling.tile_size + halo_left + halo_right;\n\n    let padded_h = tiling.tiles_y * tiling.tile_size + halo_top + halo_bottom;\n    let padded_w = tiling.tiles_x * tiling.tile_size + halo_left + halo_right;\n    if halo_top + h > padded_h || halo_left + w > padded_w {\n        return Err(DsperseError::Pipeline(format!(\n            \"split_into_tiles: input spatial ({h}x{w}) exceeds padded grid ({padded_h}x{padded_w})\"\n        )));\n    }\n    let mut padded = Array4::<f64>::zeros((n, c, padded_h, padded_w));\n    padded\n        .slice_mut(s![.., .., halo_top..halo_top + h, 
halo_left..halo_left + w])\n        .assign(input);\n\n    let mut tiles = Vec::new();\n    for ty in 0..tiling.tiles_y {\n        for tx in 0..tiling.tiles_x {\n            let y_start = ty * tiling.tile_size;\n            let x_start = tx * tiling.tile_size;\n            let tile = padded\n                .slice(s![\n                    ..,\n                    ..,\n                    y_start..y_start + tile_h,\n                    x_start..x_start + tile_w\n                ])\n                .to_owned();\n            tiles.push(tile);\n        }\n    }\n\n    Ok(tiles)\n}\n\npub fn reconstruct_from_tiles(\n    tile_outputs: &[ArrayD<f64>],\n    tiling: &TilingInfo,\n) -> Result<ArrayD<f64>> {\n    let expected_tiles = tiling.tiles_y * tiling.tiles_x;\n    if tile_outputs.len() != expected_tiles {\n        return Err(DsperseError::Pipeline(format!(\n            \"reconstruct: expected {} tiles ({}x{}), got {}\",\n            expected_tiles,\n            tiling.tiles_y,\n            tiling.tiles_x,\n            tile_outputs.len()\n        )));\n    }\n\n    let out_h = tiling.out_tile[0].max(1) as usize;\n    let out_w = tiling.out_tile[1].max(1) as usize;\n    let c_out = tiling.c_out;\n    let total_h = out_h * tiling.tiles_y;\n    let total_w = out_w * tiling.tiles_x;\n\n    let mut output = Array4::<f64>::zeros((1, c_out, total_h, total_w));\n\n    for (idx, tile_arr) in tile_outputs.iter().enumerate() {\n        let ty = idx / tiling.tiles_x;\n        let tx = idx % tiling.tiles_x;\n\n        let tile_flat: Vec<f64> = tile_arr.iter().copied().collect();\n        if tile_flat.is_empty() {\n            return Err(DsperseError::Pipeline(format!(\n                \"tile ({},{}) marked successful but produced no data\",\n                ty, tx\n            )));\n        }\n\n        let tile_elements = c_out * out_h * out_w;\n        if tile_flat.len() != tile_elements {\n            return Err(DsperseError::Pipeline(format!(\n                \"tile ({},{}) has 
{} elements, expected {} (c_out={}, out_h={}, out_w={})\",\n                ty,\n                tx,\n                tile_flat.len(),\n                tile_elements,\n                c_out,\n                out_h,\n                out_w\n            )));\n        }\n\n        let tile_4d = Array4::from_shape_vec((1, c_out, out_h, out_w), tile_flat.to_vec())\n            .map_err(|e| {\n                DsperseError::Pipeline(format!(\"tile ({},{}) reshape failed: {e}\", ty, tx))\n            })?;\n        let y_start = ty * out_h;\n        let x_start = tx * out_w;\n        output\n            .slice_mut(s![\n                ..,\n                ..,\n                y_start..y_start + out_h,\n                x_start..x_start + out_w\n            ])\n            .assign(&tile_4d);\n    }\n\n    Ok(output.into_dyn())\n}\n\npub(crate) fn trim_to_original_dims(arr: ArrayD<f64>, tiling: &TilingInfo) -> Result<ArrayD<f64>> {\n    if tiling.h == 0 || tiling.w == 0 {\n        return Ok(arr);\n    }\n    let stride_h = tiling.stride[0].max(1) as usize;\n    let stride_w = tiling.stride[1].max(1) as usize;\n    let expected_h = tiling.h / stride_h;\n    let expected_w = tiling.w / stride_w;\n    let grid_h = tiling.out_tile[0].max(1) as usize * tiling.tiles_y;\n    let grid_w = tiling.out_tile[1].max(1) as usize * tiling.tiles_x;\n    if grid_h > expected_h || grid_w > expected_w {\n        if arr.ndim() != 4 {\n            return Err(DsperseError::Pipeline(format!(\n                \"trim_to_original_dims: expected 4D array, got {}D\",\n                arr.ndim()\n            )));\n        }\n        Ok(arr\n            .slice(s![.., .., ..expected_h, ..expected_w])\n            .to_owned()\n            .into_dyn())\n    } else {\n        Ok(arr)\n    }\n}\n\npub(crate) fn split_into_tiles_1d(\n    input: &ArrayD<f64>,\n    tiling: &TilingInfo,\n) -> Result<Vec<ArrayD<f64>>> {\n    let shape = input.shape();\n    if shape.len() != 3 {\n        return 
Err(DsperseError::Pipeline(format!(\n            \"split_into_tiles_1d: expected 3D input, got {}D\",\n            shape.len()\n        )));\n    }\n    let (n, seq, _hidden) = (shape[0], shape[1], shape[2]);\n    if n != 1 {\n        return Err(DsperseError::Pipeline(format!(\n            \"split_into_tiles_1d: batch size {n} not supported, expected 1\"\n        )));\n    }\n    let tile_size = tiling.tile_size;\n    if tile_size == 0 || tiling.tiles_y == 0 {\n        return Err(DsperseError::Pipeline(format!(\n            \"split_into_tiles_1d: invalid tiling config tile_size={}, tiles_y={}\",\n            tile_size, tiling.tiles_y\n        )));\n    }\n    let padded_seq = tiling\n        .tiles_y\n        .checked_mul(tile_size)\n        .ok_or_else(|| DsperseError::Pipeline(\"split_into_tiles_1d: padded_seq overflow\".into()))?;\n    if seq > padded_seq {\n        return Err(DsperseError::Pipeline(format!(\n            \"split_into_tiles_1d: input seq {seq} exceeds padded seq {padded_seq}\"\n        )));\n    }\n    let mut padded = ArrayD::<f64>::zeros(vec![n, padded_seq, shape[2]]);\n    padded.slice_mut(s![.., ..seq, ..]).assign(input);\n\n    let mut tiles = Vec::with_capacity(tiling.tiles_y);\n    for ty in 0..tiling.tiles_y {\n        let start = ty * tile_size;\n        let tile = padded\n            .slice(s![.., start..start + tile_size, ..])\n            .to_owned()\n            .into_dyn();\n        tiles.push(tile);\n    }\n    Ok(tiles)\n}\n\npub(crate) fn reconstruct_from_tiles_1d(\n    tile_outputs: &[ArrayD<f64>],\n    tiling: &TilingInfo,\n) -> Result<ArrayD<f64>> {\n    if tile_outputs.is_empty() {\n        return Err(DsperseError::Pipeline(\n            \"reconstruct_1d: no tile outputs\".into(),\n        ));\n    }\n    if tile_outputs.len() != tiling.tiles_y {\n        return Err(DsperseError::Pipeline(format!(\n            \"reconstruct_1d: expected {} tiles, got {}\",\n            tiling.tiles_y,\n            tile_outputs.len()\n        
)));\n    }\n    let first = &tile_outputs[0];\n    if first.ndim() != 3 {\n        return Err(DsperseError::Pipeline(format!(\n            \"reconstruct_1d: expected 3D tiles, got {}D\",\n            first.ndim()\n        )));\n    }\n    let fshape = first.shape();\n    let (tile_len, hidden) = (fshape[1], fshape[2]);\n    let total_seq = tile_len * tile_outputs.len();\n    let mut output = ArrayD::<f64>::zeros(vec![1, total_seq, hidden]);\n    for (idx, tile) in tile_outputs.iter().enumerate() {\n        if tile.shape() != fshape {\n            return Err(DsperseError::Pipeline(format!(\n                \"reconstruct_1d: tile {idx} shape {:?} != first tile shape {:?}\",\n                tile.shape(),\n                fshape\n            )));\n        }\n        let start = idx * tile_len;\n        output\n            .slice_mut(s![.., start..start + tile_len, ..])\n            .assign(tile);\n    }\n    Ok(output)\n}\n\npub(crate) fn trim_to_original_seq(arr: ArrayD<f64>, tiling: &TilingInfo) -> Result<ArrayD<f64>> {\n    if tiling.h == 0 {\n        return Ok(arr);\n    }\n    if arr.ndim() != 3 {\n        return Err(DsperseError::Pipeline(format!(\n            \"trim_to_original_seq: expected 3D array, got {}D\",\n            arr.ndim()\n        )));\n    }\n    let current_seq = arr.shape()[1];\n    if current_seq > tiling.h {\n        Ok(arr.slice(s![.., ..tiling.h, ..]).to_owned().into_dyn())\n    } else {\n        Ok(arr)\n    }\n}\n\npub(crate) fn prepare_fixed_segments_from_cache(\n    tiling: &TilingInfo,\n    tensor_cache: &TensorStore,\n) -> Result<Vec<Vec<ArrayD<f64>>>> {\n    let segment_size = tiling.segment_size.ok_or_else(|| {\n        DsperseError::Pipeline(\"fixed segment tiling missing segment_size\".into())\n    })?;\n    if segment_size == 0 {\n        return Err(DsperseError::Pipeline(\n            \"fixed segment tiling has segment_size=0\".into(),\n        ));\n    }\n    let total_elements = tiling.total_elements.ok_or_else(|| {\n        
DsperseError::Pipeline(\"fixed segment tiling missing total_elements\".into())\n    })?;\n    let all_names = tiling.all_input_names();\n    let num_segments = total_elements.div_ceil(segment_size);\n    let mut all_segments: Vec<Vec<ArrayD<f64>>> = Vec::with_capacity(all_names.len());\n    for name in &all_names {\n        let input_arr = tensor_cache.get(name)?.clone();\n        let flat: Vec<f64> = input_arr.iter().copied().collect();\n        if flat.len() < total_elements {\n            return Err(DsperseError::Pipeline(format!(\n                \"fixed segment: input '{}' has {} elements, expected at least {}\",\n                name,\n                flat.len(),\n                total_elements\n            )));\n        }\n        let mut segments = Vec::with_capacity(num_segments);\n        for i in 0..num_segments {\n            let start = i * segment_size;\n            let end = (start + segment_size).min(total_elements);\n            let mut seg_data = vec![0.0f64; segment_size];\n            seg_data[..end - start].copy_from_slice(&flat[start..end]);\n            let seg = ArrayD::from_shape_vec(IxDyn(&[segment_size]), seg_data)\n                .map_err(|e| DsperseError::Pipeline(format!(\"fixed segment reshape: {e}\")))?;\n            segments.push(seg);\n        }\n        all_segments.push(segments);\n    }\n    Ok(all_segments)\n}\n\npub(crate) fn reconstruct_from_fixed_segments(\n    segment_outputs: &[ArrayD<f64>],\n    tiling: &TilingInfo,\n) -> Result<ArrayD<f64>> {\n    let total_elements = tiling.total_elements.ok_or_else(|| {\n        DsperseError::Pipeline(\"reconstruct fixed segments: missing total_elements\".into())\n    })?;\n    if segment_outputs.is_empty() {\n        return Err(DsperseError::Pipeline(\n            \"reconstruct fixed segments: no outputs\".into(),\n        ));\n    }\n    let mut flat = Vec::with_capacity(total_elements);\n    for seg in segment_outputs {\n        flat.extend(seg.iter().copied());\n    }\n    
flat.truncate(total_elements);\n    let shape: Vec<usize> = if tiling.original_shape.is_empty() {\n        vec![total_elements]\n    } else {\n        tiling.original_shape.iter().map(|&d| d as usize).collect()\n    };\n    ArrayD::from_shape_vec(IxDyn(&shape), flat)\n        .map_err(|e| DsperseError::Pipeline(format!(\"reconstruct fixed segments reshape: {e}\")))\n}\n\npub(crate) fn reshape_to_4d(flat: &[f64], c: usize, h: usize, w: usize) -> Result<Array4<f64>> {\n    let n = 1usize;\n    let total = flat.len();\n    if n * c * h * w != total {\n        return Err(DsperseError::Pipeline(format!(\n            \"cannot reshape {total} elements to 4D (n={n}, c={c}, h={h}, w={w})\"\n        )));\n    }\n    Array4::from_shape_vec((n, c, h, w), flat.to_vec())\n        .map_err(|e| DsperseError::Pipeline(format!(\"reshape: {e}\")))\n}\n\npub(crate) fn flatten_tile_inputs(all_tiles: &[Vec<ArrayD<f64>>], tile_idx: usize) -> Vec<f64> {\n    let total: usize = all_tiles.iter().map(|tiles| tiles[tile_idx].len()).sum();\n    let mut flat = Vec::with_capacity(total);\n    for input_tiles in all_tiles {\n        flat.extend(input_tiles[tile_idx].iter().copied());\n    }\n    flat\n}\n"
  },
  {
    "path": "crates/dsperse/src/pipeline/verifier.rs",
    "content": "use std::path::Path;\n\nuse crate::backend::ProofBackend;\nuse crate::error::Result;\nuse crate::schema::execution::RunMetadata;\n\nuse super::stage::{PipelineStage, run_pipeline_stage};\n\npub fn verify_run(\n    run_dir: &Path,\n    slices_dir: &Path,\n    backend: &dyn ProofBackend,\n    parallel: usize,\n) -> Result<RunMetadata> {\n    run_pipeline_stage(\n        PipelineStage::Verify,\n        run_dir,\n        slices_dir,\n        backend,\n        parallel,\n    )\n}\n"
  },
  {
    "path": "crates/dsperse/src/python.rs",
    "content": "use std::path::PathBuf;\n\nuse pyo3::exceptions::PyRuntimeError;\nuse pyo3::prelude::*;\n\nuse crate::backend::jstprove::JstproveBackend;\nuse crate::error::DsperseError;\nuse crate::pipeline::{self, RunConfig};\n\nuse jstprove_circuits::api::{ProofSystemParseError, ProofSystemType as ProofSystem};\n\nfn to_py_err(e: DsperseError) -> PyErr {\n    let msg = e.to_string();\n    match e {\n        DsperseError::Io { .. } => pyo3::exceptions::PyIOError::new_err(msg),\n        DsperseError::MsgpackEncode(_) | DsperseError::MsgpackDecode(_) => {\n            pyo3::exceptions::PyValueError::new_err(msg)\n        }\n        DsperseError::Archive(_) | DsperseError::Metadata(_) => {\n            pyo3::exceptions::PyValueError::new_err(msg)\n        }\n        DsperseError::Onnx(_)\n        | DsperseError::Backend(_)\n        | DsperseError::Slicer(_)\n        | DsperseError::Pipeline(_)\n        | DsperseError::Other(_) => PyRuntimeError::new_err(msg),\n    }\n}\n\nfn to_pretty_json<T: serde::Serialize>(value: &T) -> PyResult<String> {\n    serde_json::to_string_pretty(value).map_err(|e| {\n        to_py_err(DsperseError::Other(format!(\n            \"pretty-json serialization failed: {e}\"\n        )))\n    })\n}\n\nfn resolve_ops(proof_system: &str, circuit_ops: Option<&[String]>) -> PyResult<Vec<String>> {\n    let ps: ProofSystem = proof_system\n        .parse()\n        .map_err(|e: ProofSystemParseError| PyRuntimeError::new_err(e.to_string()))?;\n    let supported = ps.supported_ops();\n    match circuit_ops {\n        None => Ok(supported.iter().map(|s| (*s).to_string()).collect()),\n        Some(ops) => {\n            for op in ops {\n                if !supported.contains(&op.as_str()) {\n                    return Err(PyRuntimeError::new_err(format!(\n                        \"op {op:?} not supported by proof system {ps}. 
Supported: {supported:?}\"\n                    )));\n                }\n            }\n            Ok(ops.to_vec())\n        }\n    }\n}\n\nfn require_nonzero(parallel: usize) -> PyResult<()> {\n    if parallel == 0 {\n        return Err(pyo3::exceptions::PyValueError::new_err(\n            \"parallel must be > 0\",\n        ));\n    }\n    Ok(())\n}\n\n#[pyfunction]\n#[pyo3(signature = (model_path, output_dir=None, tile_size=None, proof_system=\"expander\", circuit_ops=None, input_shape=None))]\nfn slice_model(\n    py: Python<'_>,\n    model_path: &str,\n    output_dir: Option<&str>,\n    tile_size: Option<usize>,\n    proof_system: &str,\n    circuit_ops: Option<Vec<String>>,\n    input_shape: Option<Vec<i64>>,\n) -> PyResult<String> {\n    let model = PathBuf::from(model_path);\n    let out = output_dir.map(PathBuf::from);\n    let ops = resolve_ops(proof_system, circuit_ops.as_deref())?;\n    let ops_refs: Vec<&str> = ops.iter().map(String::as_str).collect();\n    let metadata = py\n        .allow_threads(|| {\n            crate::slicer::slice_model(\n                &model,\n                out.as_deref(),\n                tile_size,\n                &ops_refs,\n                input_shape.as_deref(),\n            )\n        })\n        .map_err(to_py_err)?;\n    to_pretty_json(&metadata)\n}\n\n#[pyfunction]\n#[allow(clippy::too_many_arguments)]\n#[pyo3(signature = (slices_dir, proof_config=\"bn254_raw\", parallel=1, weights_as_inputs=true, layers=None, proof_system=\"expander\", circuit_ops=None, skip_compile_over_size=None, holographic=false))]\nfn compile_slices(\n    py: Python<'_>,\n    slices_dir: &str,\n    proof_config: &str,\n    parallel: usize,\n    weights_as_inputs: bool,\n    layers: Option<Vec<usize>>,\n    proof_system: &str,\n    circuit_ops: Option<Vec<String>>,\n    skip_compile_over_size: Option<u64>,\n    holographic: bool,\n) -> PyResult<()> {\n    require_nonzero(parallel)?;\n    let backend = JstproveBackend::default();\n    let 
parsed_config: jstprove_circuits::api::ProofConfigType =\n        proof_config\n            .parse()\n            .map_err(|e: jstprove_circuits::api::ProofConfigError| {\n                pyo3::exceptions::PyValueError::new_err(e.to_string())\n            })?;\n    let dir = PathBuf::from(slices_dir);\n    let ops = resolve_ops(proof_system, circuit_ops.as_deref())?;\n    let ops_refs: Vec<&str> = ops.iter().map(String::as_str).collect();\n    let report = py\n        .allow_threads(|| {\n            pipeline::compile_slices(\n                &dir,\n                &backend,\n                parsed_config,\n                parallel,\n                weights_as_inputs,\n                layers.as_deref(),\n                &ops_refs,\n                skip_compile_over_size,\n                holographic,\n            )\n        })\n        .map_err(to_py_err)?;\n    // Propagate partial-compile failures to the Python caller so\n    // silent non-zero masks become impossible, but phrase the\n    // message in Python-binding terms rather than reusing the\n    // CLI's --allow-onnx-fallback hint (the Python API has its\n    // own opt-in path).\n    if !report.failed.is_empty() {\n        return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(\n            \"partial compile failures: {} slice(s) failed to compile; catch this exception and continue to accept ONNX fallback for the failed slices, or inspect the Rust-side logs for the per-slice error payload\",\n            report.failed.len()\n        )));\n    }\n    Ok(())\n}\n\n#[pyfunction]\n#[allow(clippy::too_many_arguments)]\n#[pyo3(signature = (slices_dir, input_file, run_dir, parallel=1, batch=false, weights_onnx=None, combined=true))]\nfn run_inference(\n    py: Python<'_>,\n    slices_dir: &str,\n    input_file: &str,\n    run_dir: &str,\n    parallel: usize,\n    batch: bool,\n    weights_onnx: Option<&str>,\n    combined: bool,\n) -> PyResult<String> {\n    require_nonzero(parallel)?;\n    let backend = 
JstproveBackend::default();\n    let config = RunConfig {\n        parallel,\n        batch,\n        weights_onnx: weights_onnx.map(PathBuf::from),\n        combined,\n    };\n    let sd = PathBuf::from(slices_dir);\n    let inf = PathBuf::from(input_file);\n    let rd = PathBuf::from(run_dir);\n    let metadata = py\n        .allow_threads(|| pipeline::run_inference(&sd, &inf, &rd, &backend, &config))\n        .map_err(to_py_err)?;\n    to_pretty_json(&metadata)\n}\n\n#[pyfunction]\n#[pyo3(signature = (run_dir, slices_dir, parallel=1))]\nfn prove_run(py: Python<'_>, run_dir: &str, slices_dir: &str, parallel: usize) -> PyResult<String> {\n    require_nonzero(parallel)?;\n    let backend = JstproveBackend::default();\n    let rd = PathBuf::from(run_dir);\n    let sd = PathBuf::from(slices_dir);\n    let metadata = py\n        .allow_threads(|| pipeline::prove_run(&rd, &sd, &backend, parallel))\n        .map_err(to_py_err)?;\n    to_pretty_json(&metadata)\n}\n\n#[pyfunction]\n#[pyo3(signature = (run_dir, slices_dir, parallel=1))]\nfn verify_run(\n    py: Python<'_>,\n    run_dir: &str,\n    slices_dir: &str,\n    parallel: usize,\n) -> PyResult<String> {\n    require_nonzero(parallel)?;\n    let backend = JstproveBackend::default();\n    let rd = PathBuf::from(run_dir);\n    let sd = PathBuf::from(slices_dir);\n    let metadata = py\n        .allow_threads(|| pipeline::verify_run(&rd, &sd, &backend, parallel))\n        .map_err(to_py_err)?;\n    to_pretty_json(&metadata)\n}\n\n#[pyfunction]\n#[pyo3(signature = (argv=None))]\nfn cli_main(py: Python<'_>, argv: Option<Vec<String>>) -> PyResult<()> {\n    use clap::Parser;\n    use tracing_subscriber::EnvFilter;\n\n    let cli = match argv {\n        Some(args) => crate::cli::Cli::try_parse_from(args.clone()).or_else(|_| {\n            let mut with_prog = vec![\"dsperse\".to_string()];\n            with_prog.extend(args);\n            crate::cli::Cli::try_parse_from(with_prog)\n        }),\n        None => 
crate::cli::Cli::try_parse(),\n    }\n    .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;\n\n    let _ = tracing_subscriber::fmt()\n        .with_env_filter(\n            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),\n        )\n        .try_init();\n\n    eprintln!(\"dsperse {}\", crate::cli::VERSION);\n\n    let result = py.allow_threads(|| crate::cli::dispatch(cli.command));\n\n    result.map_err(to_py_err)\n}\n\n#[pyfunction]\n#[pyo3(signature = (slices_dir, parallel=1, overwrite=false))]\nfn setup_holographic(\n    py: Python<'_>,\n    slices_dir: &str,\n    parallel: usize,\n    overwrite: bool,\n) -> PyResult<()> {\n    require_nonzero(parallel)?;\n    let backend = JstproveBackend::default();\n    let dir = PathBuf::from(slices_dir);\n    let report = py\n        .allow_threads(|| {\n            pipeline::setup_holographic_for_slices(&dir, &backend, parallel, overwrite)\n        })\n        .map_err(to_py_err)?;\n    if !report.failed.is_empty() {\n        return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(\n            \"{} bundle(s) failed holographic setup\",\n            report.failed.len()\n        )));\n    }\n    Ok(())\n}\n\n#[pymodule]\nfn _native(m: &Bound<'_, PyModule>) -> PyResult<()> {\n    m.add_function(wrap_pyfunction!(slice_model, m)?)?;\n    m.add_function(wrap_pyfunction!(compile_slices, m)?)?;\n    m.add_function(wrap_pyfunction!(run_inference, m)?)?;\n    m.add_function(wrap_pyfunction!(prove_run, m)?)?;\n    m.add_function(wrap_pyfunction!(verify_run, m)?)?;\n    m.add_function(wrap_pyfunction!(setup_holographic, m)?)?;\n    m.add_function(wrap_pyfunction!(cli_main, m)?)?;\n    Ok(())\n}\n"
  },
  {
    "path": "crates/dsperse/src/schema/execution.rs",
    "content": "use std::collections::HashMap;\n\nuse serde::{Deserialize, Serialize};\n\nuse super::metadata::{BackendKind, RunSliceMetadata};\n\n#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub enum ExecutionMethod {\n    JstproveGenWitness,\n    OnnxOnly,\n    Tiled,\n    ChannelSplit,\n    DimSplit,\n    JstproveProve,\n    JstproveVerify,\n}\n\nimpl std::fmt::Display for ExecutionMethod {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            Self::JstproveGenWitness => write!(f, \"jstprove_gen_witness\"),\n            Self::OnnxOnly => write!(f, \"onnx_only\"),\n            Self::Tiled => write!(f, \"tiled\"),\n            Self::ChannelSplit => write!(f, \"channel_split\"),\n            Self::DimSplit => write!(f, \"dim_split\"),\n            Self::JstproveProve => write!(f, \"jstprove_prove\"),\n            Self::JstproveVerify => write!(f, \"jstprove_verify\"),\n        }\n    }\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct TileResult {\n    pub tile_idx: usize,\n    pub success: bool,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub error: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub method: Option<ExecutionMethod>,\n    #[serde(default, skip_serializing_if = \"is_zero\")]\n    pub time_sec: f64,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub proof_path: Option<String>,\n}\n\nimpl TileResult {\n    pub fn failure(\n        tile_idx: usize,\n        error: String,\n        method: Option<ExecutionMethod>,\n        time_sec: f64,\n    ) -> Self {\n        Self {\n            tile_idx,\n            success: false,\n            error: Some(error),\n            method,\n            time_sec,\n            proof_path: None,\n        }\n    }\n\n    pub fn success(tile_idx: usize, method: Option<ExecutionMethod>, time_sec: f64) -> 
Self {\n        Self {\n            tile_idx,\n            success: true,\n            error: None,\n            method,\n            time_sec,\n            proof_path: None,\n        }\n    }\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct ExecutionInfo {\n    pub method: ExecutionMethod,\n    #[serde(default)]\n    pub success: bool,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub error: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub witness_file: Option<String>,\n    #[serde(default, skip_serializing_if = \"Vec::is_empty\", alias = \"tiles\")]\n    pub tile_exec_infos: Vec<TileResult>,\n}\n\n#[derive(Debug)]\npub struct StrategyOutput {\n    pub info: ExecutionInfo,\n    pub outputs: Vec<(String, ndarray::ArrayD<f64>)>,\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct SliceResult {\n    pub slice_id: String,\n    pub success: bool,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub method: Option<ExecutionMethod>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub error: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub proof_path: Option<String>,\n    #[serde(default, skip_serializing_if = \"is_zero\")]\n    pub time_sec: f64,\n    #[serde(default, skip_serializing_if = \"Vec::is_empty\")]\n    pub tiles: Vec<TileResult>,\n}\n\nimpl SliceResult {\n    pub fn failure(\n        slice_id: impl Into<String>,\n        method: ExecutionMethod,\n        error: String,\n        time_sec: f64,\n    ) -> Self {\n        Self {\n            slice_id: slice_id.into(),\n            success: false,\n            method: Some(method),\n            error: Some(error),\n            proof_path: None,\n            time_sec,\n            tiles: Vec::new(),\n        }\n    }\n\n    pub fn success(slice_id: impl Into<String>, method: ExecutionMethod, time_sec: f64) -> Self {\n        Self {\n 
           slice_id: slice_id.into(),\n            success: true,\n            method: Some(method),\n            error: None,\n            proof_path: None,\n            time_sec,\n            tiles: Vec::new(),\n        }\n    }\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct ExecutionNode {\n    pub slice_id: String,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub primary: Option<String>,\n    #[serde(default)]\n    pub fallbacks: Vec<String>,\n    #[serde(default)]\n    pub use_circuit: bool,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub next: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub circuit_path: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub onnx_path: Option<String>,\n    #[serde(default)]\n    pub backend: BackendKind,\n}\n\nimpl Default for ExecutionNode {\n    fn default() -> Self {\n        Self {\n            slice_id: String::new(),\n            primary: None,\n            fallbacks: Vec::new(),\n            use_circuit: false,\n            next: None,\n            circuit_path: None,\n            onnx_path: None,\n            backend: BackendKind::Onnx,\n        }\n    }\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct ExecutionResultEntry {\n    pub slice_id: String,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub witness_execution: Option<ExecutionInfo>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub proof_execution: Option<SliceResult>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub verification_execution: Option<SliceResult>,\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\npub struct ExecutionChain {\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub head: Option<String>,\n    #[serde(default)]\n    pub nodes: HashMap<String, ExecutionNode>,\n   
 #[serde(default)]\n    pub fallback_map: HashMap<String, Vec<String>>,\n    #[serde(default)]\n    pub execution_results: Vec<ExecutionResultEntry>,\n    #[serde(default)]\n    pub jstprove_proved_slices: usize,\n    #[serde(default)]\n    pub jstprove_verified_slices: usize,\n}\n\nimpl ExecutionChain {\n    pub fn get_result_for_slice(&self, slice_id: &str) -> Option<&ExecutionResultEntry> {\n        self.execution_results\n            .iter()\n            .find(|e| e.slice_id == slice_id)\n    }\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\npub struct RunMetadata {\n    #[serde(default)]\n    pub slices: HashMap<String, RunSliceMetadata>,\n    #[serde(default)]\n    pub execution_chain: ExecutionChain,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub packaging_type: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub source_path: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub run_directory: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub model_path: Option<String>,\n}\n\nimpl RunMetadata {\n    pub fn get_slice(&self, slice_id: &str) -> Option<&RunSliceMetadata> {\n        self.slices.get(slice_id)\n    }\n\n    pub fn iter_circuit_slices(&self) -> impl Iterator<Item = (&str, &RunSliceMetadata)> {\n        self.execution_chain\n            .nodes\n            .iter()\n            .filter(|(_, node)| node.use_circuit)\n            .filter_map(|(slice_id, _)| {\n                self.slices\n                    .get(slice_id)\n                    .map(|meta| (slice_id.as_str(), meta))\n            })\n    }\n}\n\nfn is_zero(v: &f64) -> bool {\n    *v == 0.0\n}\n"
  },
  {
    "path": "crates/dsperse/src/schema/metadata.rs",
    "content": "use std::collections::HashMap;\n\nuse serde::{Deserialize, Serialize};\n\nuse super::tiling::{ChannelSplitInfo, DimSplitInfo, TilingInfo};\n\n#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]\n#[serde(rename_all = \"lowercase\")]\npub enum BackendKind {\n    #[serde(alias = \"JSTPROVE\")]\n    Jstprove,\n    #[default]\n    Onnx,\n}\n\nimpl std::fmt::Display for BackendKind {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            Self::Jstprove => write!(f, \"jstprove\"),\n            Self::Onnx => write!(f, \"onnx\"),\n        }\n    }\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\npub struct TensorShape {\n    #[serde(default)]\n    pub input: Vec<Vec<i64>>,\n    #[serde(default)]\n    pub output: Vec<Vec<i64>>,\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\npub struct Dependencies {\n    #[serde(default)]\n    pub input: Vec<String>,\n    #[serde(default)]\n    pub output: Vec<String>,\n    #[serde(default)]\n    pub filtered_inputs: Vec<String>,\n}\n\n#[derive(Debug, Clone, Default, Deserialize)]\npub struct CompilationFiles {\n    #[serde(default, alias = \"compiled_circuit\", alias = \"circuit\")]\n    pub compiled: Option<String>,\n    #[serde(default)]\n    pub settings: Option<String>,\n    #[serde(default)]\n    pub pk_key: Option<String>,\n    #[serde(default)]\n    pub vk_key: Option<String>,\n}\n\n#[derive(Debug, Clone, Default, Deserialize)]\npub struct BackendCompilation {\n    #[serde(default)]\n    pub compiled: bool,\n    #[serde(default)]\n    pub tiled: bool,\n    #[serde(default)]\n    pub weights_as_inputs: bool,\n    #[serde(default)]\n    pub files: CompilationFiles,\n    #[serde(default)]\n    pub compilation_timestamp: Option<String>,\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\n#[serde(default)]\npub struct Compilation {\n    #[serde(skip_serializing)]\n    pub jstprove: 
BackendCompilation,\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\npub struct SliceShapeWrapper {\n    #[serde(default)]\n    pub tensor_shape: TensorShape,\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\npub struct SliceMetadata {\n    #[serde(default)]\n    pub index: usize,\n    #[serde(default)]\n    pub filename: String,\n    #[serde(default)]\n    pub path: String,\n    #[serde(default)]\n    pub relative_path: String,\n    #[serde(default)]\n    pub shape: SliceShapeWrapper,\n    #[serde(default)]\n    pub dependencies: Dependencies,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub tiling: Option<TilingInfo>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub channel_split: Option<ChannelSplitInfo>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub dim_split: Option<DimSplitInfo>,\n    #[serde(default)]\n    pub compilation: Compilation,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub slice_metadata: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub slice_metadata_relative_path: Option<String>,\n}\n\nimpl SliceMetadata {\n    pub fn split_strategy(&self) -> Option<super::tiling::SplitStrategy<'_>> {\n        use super::tiling::SplitStrategy;\n        self.tiling\n            .as_ref()\n            .map(SplitStrategy::Tiled)\n            .or_else(|| self.channel_split.as_ref().map(SplitStrategy::ChannelSplit))\n            .or_else(|| self.dim_split.as_ref().map(SplitStrategy::DimSplit))\n    }\n\n    pub fn output_names(&self) -> &[String] {\n        &self.dependencies.output\n    }\n\n    pub fn resolve_onnx(\n        &self,\n        slices_dir: &std::path::Path,\n    ) -> crate::error::Result<std::path::PathBuf> {\n        if self.relative_path.is_empty() {\n            Ok(slices_dir.join(\"model.onnx\"))\n        } else {\n            
crate::utils::paths::resolve_relative_path(slices_dir, &self.relative_path)\n        }\n    }\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct RunSliceMetadata {\n    #[serde(default)]\n    pub path: String,\n    #[serde(default)]\n    pub input_shape: Vec<Vec<i64>>,\n    #[serde(default)]\n    pub output_shape: Vec<Vec<i64>>,\n    #[serde(default)]\n    pub dependencies: Dependencies,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub tiling: Option<TilingInfo>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub channel_split: Option<ChannelSplitInfo>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub dim_split: Option<DimSplitInfo>,\n    #[serde(default)]\n    pub backend: BackendKind,\n    #[serde(\n        default,\n        skip_serializing_if = \"Option::is_none\",\n        alias = \"circuit_path\"\n    )]\n    pub jstprove_circuit_path: Option<String>,\n    #[serde(\n        default,\n        skip_serializing_if = \"Option::is_none\",\n        alias = \"settings_path\"\n    )]\n    pub jstprove_settings_path: Option<String>,\n}\n\nimpl RunSliceMetadata {\n    pub fn split_strategy(&self) -> Option<super::tiling::SplitStrategy<'_>> {\n        use super::tiling::SplitStrategy;\n        self.tiling\n            .as_ref()\n            .map(SplitStrategy::Tiled)\n            .or_else(|| self.channel_split.as_ref().map(SplitStrategy::ChannelSplit))\n            .or_else(|| self.dim_split.as_ref().map(SplitStrategy::DimSplit))\n    }\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\npub struct ModelMetadata {\n    #[serde(default)]\n    pub original_model: String,\n    #[serde(default)]\n    pub model_type: String,\n    #[serde(default)]\n    pub input_shape: Vec<Vec<i64>>,\n    #[serde(default)]\n    pub output_shapes: Vec<Vec<i64>>,\n    #[serde(default)]\n    pub output_names: Vec<String>,\n    #[serde(default)]\n    pub slice_points: Vec<usize>,\n    
#[serde(default)]\n    pub slices: Vec<SliceMetadata>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub dsperse_version: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub dsperse_rev: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub jstprove_version: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub jstprove_rev: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub traced_shapes: Option<HashMap<String, Vec<i64>>>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub traced_types: Option<HashMap<String, i32>>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub original_model_path: Option<String>,\n    #[serde(default, skip_serializing_if = \"Vec::is_empty\")]\n    pub folded_constant_names: Vec<String>,\n}\n\nimpl ModelMetadata {\n    pub fn load(path: &std::path::Path) -> crate::error::Result<Self> {\n        let data = crate::utils::limits::read_checked(path)?;\n        rmp_serde::from_slice(&data).map_err(Into::into)\n    }\n\n    pub fn save(&self, path: &std::path::Path) -> crate::error::Result<()> {\n        if let Some(parent) = path.parent() {\n            std::fs::create_dir_all(parent)\n                .map_err(|e| crate::error::DsperseError::io(e, parent))?;\n        }\n        let data = rmp_serde::to_vec_named(self)?;\n        let tmp_path = path.with_extension(\"msgpack.tmp\");\n        std::fs::write(&tmp_path, &data)\n            .map_err(|e| crate::error::DsperseError::io(e, &tmp_path))?;\n        std::fs::rename(&tmp_path, path).map_err(|e| crate::error::DsperseError::io(e, path))\n    }\n\n    pub fn stamp_version(&mut self) {\n        let ver = crate::version::dsperse_artifact_version();\n        self.dsperse_version = Some(ver.dsperse_version);\n        self.dsperse_rev = ver.dsperse_rev;\n        
self.jstprove_version = Some(ver.jstprove_version);\n        self.jstprove_rev = ver.jstprove_rev;\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/schema/mod.rs",
    "content": "pub mod execution;\npub mod metadata;\npub mod tiling;\n\npub use execution::*;\npub use metadata::*;\npub use tiling::*;\n"
  },
  {
    "path": "crates/dsperse/src/schema/tiling.rs",
    "content": "use serde::{self, Deserialize, Deserializer, Serialize};\n\n#[derive(Debug, Clone)]\npub enum SplitStrategy<'a> {\n    Tiled(&'a TilingInfo),\n    ChannelSplit(&'a ChannelSplitInfo),\n    DimSplit(&'a DimSplitInfo),\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct TileInfo {\n    #[serde(default)]\n    pub path: String,\n    #[serde(default = \"default_pair_zero\")]\n    pub conv_out: [i64; 2],\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub jstprove_circuit_path: Option<String>,\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct TilingInfo {\n    #[serde(default)]\n    pub slice_idx: usize,\n    #[serde(default)]\n    pub tile_size: usize,\n    #[serde(default = \"default_one\")]\n    pub num_tiles: usize,\n    #[serde(default = \"default_one\")]\n    pub tiles_y: usize,\n    #[serde(default = \"default_one\")]\n    pub tiles_x: usize,\n    #[serde(default = \"default_quad_zero\", deserialize_with = \"deserialize_halo\")]\n    pub halo: [i64; 4],\n    #[serde(default = \"default_pair_zero\")]\n    pub out_tile: [i64; 2],\n    #[serde(default = \"default_pair_one\")]\n    pub stride: [i64; 2],\n    #[serde(default)]\n    pub c_in: usize,\n    #[serde(default)]\n    pub c_out: usize,\n    #[serde(default = \"default_input_name\")]\n    pub input_name: String,\n    #[serde(default = \"default_output_name\")]\n    pub output_name: String,\n    #[serde(default, skip_serializing_if = \"Vec::is_empty\")]\n    pub input_names: Vec<String>,\n    #[serde(default = \"default_four\")]\n    pub ndim: usize,\n    #[serde(default)]\n    pub h: usize,\n    #[serde(default)]\n    pub w: usize,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub tile: Option<TileInfo>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub tiles: Option<Vec<TileInfo>>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub segment_size: Option<usize>,\n    
#[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub total_elements: Option<usize>,\n    #[serde(default, skip_serializing_if = \"Vec::is_empty\")]\n    pub original_shape: Vec<i64>,\n}\n\nimpl TilingInfo {\n    pub fn all_input_names(&self) -> Vec<&str> {\n        if self.input_names.is_empty() {\n            vec![&self.input_name]\n        } else {\n            self.input_names.iter().map(|s| s.as_str()).collect()\n        }\n    }\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct ChannelGroupInfo {\n    #[serde(default)]\n    pub group_idx: usize,\n    #[serde(default)]\n    pub c_start: usize,\n    #[serde(default)]\n    pub c_end: usize,\n    #[serde(default)]\n    pub path: String,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub jstprove_circuit_path: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub jstprove_settings_path: Option<String>,\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct ChannelSplitInfo {\n    #[serde(default)]\n    pub slice_idx: usize,\n    #[serde(default)]\n    pub c_in: usize,\n    #[serde(default)]\n    pub c_out: usize,\n    #[serde(default = \"default_one\")]\n    pub num_groups: usize,\n    #[serde(default)]\n    pub channels_per_group: usize,\n    #[serde(default = \"default_input_name\")]\n    pub input_name: String,\n    #[serde(default = \"default_output_name\")]\n    pub output_name: String,\n    #[serde(default)]\n    pub h: usize,\n    #[serde(default)]\n    pub w: usize,\n    #[serde(default)]\n    pub out_h: usize,\n    #[serde(default)]\n    pub out_w: usize,\n    #[serde(default)]\n    pub groups: Vec<ChannelGroupInfo>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub bias_path: Option<String>,\n}\n\nfn default_one() -> usize {\n    1\n}\n\nfn default_four() -> usize {\n    4\n}\n\nfn default_pair_zero() -> [i64; 2] {\n    [0, 0]\n}\n\nfn default_pair_one() -> [i64; 2] {\n    [1, 
1]\n}\n\nfn default_quad_zero() -> [i64; 4] {\n    [0, 0, 0, 0]\n}\n\nfn deserialize_halo<'de, D>(deserializer: D) -> std::result::Result<[i64; 4], D::Error>\nwhere\n    D: Deserializer<'de>,\n{\n    let v: Vec<i64> = Vec::deserialize(deserializer)?;\n    match v.len() {\n        2 => Ok([v[0], v[1], v[0], v[1]]),\n        4 => Ok([v[0], v[1], v[2], v[3]]),\n        _ => Err(serde::de::Error::custom(format!(\n            \"expected 2 or 4 elements for halo, got {}\",\n            v.len()\n        ))),\n    }\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub enum DimSplitKind {\n    #[default]\n    MatMulOutputDim,\n    HeadDim,\n    BatchDim,\n}\n\n#[derive(Debug, Clone, Default, Serialize, Deserialize)]\npub struct DimSplitInfo {\n    #[serde(default)]\n    pub slice_idx: usize,\n    #[serde(default)]\n    pub split_kind: DimSplitKind,\n    #[serde(default)]\n    pub split_dim: usize,\n    #[serde(default)]\n    pub dim_size: usize,\n    #[serde(default = \"default_one\")]\n    pub num_groups: usize,\n    #[serde(default)]\n    pub elements_per_group: usize,\n    #[serde(default = \"default_input_name\")]\n    pub input_name: String,\n    #[serde(default = \"default_output_name\")]\n    pub output_name: String,\n    #[serde(default)]\n    pub concat_axis: usize,\n    #[serde(default)]\n    pub estimated_group_constraints: u64,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub weight_name: Option<String>,\n    #[serde(default)]\n    pub k_dim: usize,\n    #[serde(default)]\n    pub n_dim: usize,\n    #[serde(default = \"default_one\")]\n    pub k_chunks: usize,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub template_path: Option<String>,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub jstprove_circuit_path: Option<String>,\n}\n\nimpl DimSplitInfo {\n    pub fn from_detection(\n        d: 
&crate::slicer::autotiler::DimSplitDetection,\n        slice_idx: usize,\n        template_path: Option<String>,\n    ) -> Self {\n        let estimated_group_constraints = if d.k_chunks > 1 {\n            (d.k_dim.div_ceil(d.k_chunks) * d.n_dim * 2) as u64\n        } else if d.num_groups > 0 {\n            d.estimated_constraints / d.num_groups as u64\n        } else {\n            d.estimated_constraints\n        };\n        Self {\n            slice_idx,\n            split_kind: d.split_kind.clone(),\n            split_dim: d.split_dim,\n            dim_size: d.dim_size,\n            num_groups: d.num_groups,\n            elements_per_group: d.elements_per_group,\n            input_name: d.input_name.clone(),\n            output_name: d.output_name.clone(),\n            concat_axis: d.concat_axis,\n            estimated_group_constraints,\n            weight_name: d.weight_name.clone(),\n            k_dim: d.k_dim,\n            n_dim: d.n_dim,\n            k_chunks: d.k_chunks,\n            template_path,\n            jstprove_circuit_path: None,\n        }\n    }\n}\n\nfn default_input_name() -> String {\n    \"input\".to_string()\n}\n\nfn default_output_name() -> String {\n    \"output\".to_string()\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/analyzer.rs",
    "content": "use std::collections::{HashMap, HashSet};\nuse std::path::Path;\n\nuse serde::{Deserialize, Serialize};\n\nuse super::onnx_proto::{self, GraphProto, ModelProto, TensorProto};\nuse crate::error::{DsperseError, Result};\nuse crate::schema::metadata::Dependencies;\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct NodeAnalysis {\n    pub index: usize,\n    pub slice_name: String,\n    pub node_type: String,\n    pub parameter_details: HashMap<String, ParameterDetail>,\n    pub dependencies: NodeDependencies,\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct ParameterDetail {\n    pub shape: Vec<i64>,\n    pub size: usize,\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct NodeDependencies {\n    pub input: Vec<String>,\n    pub output: Vec<String>,\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct AnalysisResult {\n    pub original_model: Option<String>,\n    pub model_type: String,\n    pub node_count: usize,\n    pub initializer_count: usize,\n    pub input_shape: Vec<Vec<i64>>,\n    pub output_shapes: Vec<Vec<i64>>,\n    pub output_names: Vec<String>,\n    pub opset_version: Option<i64>,\n    pub nodes: HashMap<String, NodeAnalysis>,\n    pub initializer_names: HashSet<String>,\n}\n\npub fn analyze(model: &ModelProto, onnx_path: Option<&Path>) -> Result<AnalysisResult> {\n    let graph = model\n        .graph\n        .as_ref()\n        .ok_or_else(|| DsperseError::Onnx(\"model has no graph\".into()))?;\n    let initializer_map: HashMap<&str, &TensorProto> = graph\n        .initializer\n        .iter()\n        .map(|i| (i.name.as_str(), i))\n        .collect();\n\n    let input_shapes = get_model_input_shapes(graph, &initializer_map);\n    let output_shapes = get_model_output_shapes(graph);\n    let output_names = get_model_output_names(graph);\n\n    let mut nodes = HashMap::new();\n    for (i, node) in graph.node.iter().enumerate() {\n        let node_key = if node.name.is_empty() {\n     
       format!(\"{}_{}\", node.op_type, i)\n        } else {\n            node.name.clone()\n        };\n\n        let parameter_details = get_parameter_details(node, &initializer_map);\n\n        let mut inputs: Vec<String> = node\n            .input\n            .iter()\n            .filter(|s| !s.is_empty())\n            .cloned()\n            .collect();\n        if super::is_control_flow(&node.op_type) {\n            let outer_refs = super::collect_subgraph_outer_refs(node, graph);\n            for r in outer_refs {\n                if !inputs.contains(&r) {\n                    inputs.push(r);\n                }\n            }\n        }\n\n        nodes.insert(\n            node_key,\n            NodeAnalysis {\n                index: i,\n                slice_name: format!(\"{}_{}\", node.op_type, i),\n                node_type: node.op_type.clone(),\n                parameter_details,\n                dependencies: NodeDependencies {\n                    input: inputs,\n                    output: node.output.clone(),\n                },\n            },\n        );\n    }\n\n    let opset_version = model\n        .opset_import\n        .iter()\n        .find(|o| o.domain.is_empty() || o.domain == \"ai.onnx\")\n        .map(|o| o.version);\n\n    if let Some(v) = opset_version\n        && v < 18\n    {\n        tracing::warn!(opset = v, \"opset < 18 detected; continuing anyway\");\n    }\n\n    let initializer_names: HashSet<String> =\n        graph.initializer.iter().map(|i| i.name.clone()).collect();\n\n    Ok(AnalysisResult {\n        original_model: onnx_path.map(|p| p.to_string_lossy().to_string()),\n        model_type: \"ONNX\".to_string(),\n        node_count: graph.node.len(),\n        initializer_count: graph.initializer.len(),\n        input_shape: input_shapes,\n        output_shapes,\n        output_names,\n        opset_version,\n        nodes,\n        initializer_names,\n    })\n}\n\nfn get_model_input_shapes(\n    graph: &GraphProto,\n    
initializer_map: &HashMap<&str, &TensorProto>,\n) -> Vec<Vec<i64>> {\n    graph\n        .input\n        .iter()\n        .filter(|inp| !initializer_map.contains_key(inp.name.as_str()))\n        .map(onnx_proto::vi_shape)\n        .collect()\n}\n\nfn get_model_output_shapes(graph: &GraphProto) -> Vec<Vec<i64>> {\n    graph.output.iter().map(onnx_proto::vi_shape).collect()\n}\n\nfn get_model_output_names(graph: &GraphProto) -> Vec<String> {\n    graph.output.iter().map(|o| o.name.clone()).collect()\n}\n\nfn get_parameter_details(\n    node: &onnx_proto::NodeProto,\n    initializer_map: &HashMap<&str, &TensorProto>,\n) -> HashMap<String, ParameterDetail> {\n    let mut details = HashMap::new();\n    if !matches!(node.op_type.as_str(), \"Conv\" | \"Gemm\" | \"MatMul\") {\n        return details;\n    }\n    for inp_name in &node.input {\n        if let Some(init) = initializer_map.get(inp_name.as_str()) {\n            let size: usize = init.dims.iter().map(|&d| d as usize).product();\n            if size > 0 {\n                details.insert(\n                    inp_name.clone(),\n                    ParameterDetail {\n                        shape: init.dims.clone(),\n                        size,\n                    },\n                );\n            }\n        }\n    }\n    details\n}\n\npub fn get_segment_dependencies(\n    analysis: &AnalysisResult,\n    start_idx: usize,\n    end_idx: usize,\n) -> Dependencies {\n    let mut inputs = Vec::new();\n    let mut output_map: HashMap<String, bool> = HashMap::new();\n\n    let mut sorted_nodes: Vec<&NodeAnalysis> = analysis\n        .nodes\n        .values()\n        .filter(|n| n.index >= start_idx && n.index < end_idx)\n        .collect();\n    sorted_nodes.sort_by_key(|n| n.index);\n\n    let mut consumed_in_segment: HashSet<String> = HashSet::new();\n    for node in &sorted_nodes {\n        for out in &node.dependencies.output {\n            output_map.insert(out.clone(), true);\n        }\n        for inp in 
&node.dependencies.input {\n            if output_map.contains_key(inp) {\n                consumed_in_segment.insert(inp.clone());\n            }\n            if !output_map.contains_key(inp) && !inputs.contains(inp) {\n                inputs.push(inp.clone());\n            }\n        }\n    }\n\n    let model_output_set: HashSet<&str> =\n        analysis.output_names.iter().map(|s| s.as_str()).collect();\n\n    let mut outputs: Vec<String> = output_map\n        .keys()\n        .filter(|output| {\n            if inputs.contains(output) {\n                return false;\n            }\n            // Exclude tensors consumed by a later node in the same segment\n            // unless they are also model-level final outputs. The materializer\n            // only promotes internally-consumed tensors to graph outputs when\n            // a downstream segment needs them; the metadata list must match.\n            if consumed_in_segment.contains(output.as_str())\n                && !model_output_set.contains(output.as_str())\n            {\n                return false;\n            }\n            true\n        })\n        .cloned()\n        .collect();\n    outputs.sort();\n\n    let filtered = inputs\n        .iter()\n        .filter(|name| !analysis.initializer_names.contains(name.as_str()))\n        .cloned()\n        .collect::<Vec<_>>();\n\n    let filtered_inputs = if filtered.is_empty() && !inputs.is_empty() {\n        vec![inputs[0].clone()]\n    } else {\n        filtered\n    };\n\n    Dependencies {\n        input: inputs,\n        output: outputs,\n        filtered_inputs,\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    fn make_node(\n        op: &str,\n        idx: usize,\n        inputs: Vec<&str>,\n        outputs: Vec<&str>,\n    ) -> onnx_proto::NodeProto {\n        onnx_proto::NodeProto {\n            op_type: op.into(),\n            name: format!(\"{}_{}\", op, idx),\n            input: 
inputs.into_iter().map(String::from).collect(),\n            output: outputs.into_iter().map(String::from).collect(),\n            attribute: vec![],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        }\n    }\n\n    fn make_model_with_nodes(nodes: Vec<onnx_proto::NodeProto>) -> ModelProto {\n        let input = onnx_proto::make_tensor_value_info(\"x\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let output = onnx_proto::make_tensor_value_info(\"y\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let graph = onnx_proto::make_graph(\"test\", nodes, vec![input], vec![output], vec![]);\n        onnx_proto::make_model(graph, 13)\n    }\n\n    fn make_model_with_initializers(\n        nodes: Vec<onnx_proto::NodeProto>,\n        initializers: Vec<TensorProto>,\n    ) -> ModelProto {\n        let input = onnx_proto::make_tensor_value_info(\"x\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let output = onnx_proto::make_tensor_value_info(\"y\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let graph = onnx_proto::make_graph(\"test\", nodes, vec![input], vec![output], initializers);\n        onnx_proto::make_model(graph, 13)\n    }\n\n    #[test]\n    fn analyze_empty_model() {\n        let model = make_model_with_nodes(vec![]);\n        let result = analyze(&model, None).unwrap();\n        assert_eq!(result.node_count, 0);\n        assert!(result.nodes.is_empty());\n        assert_eq!(result.model_type, \"ONNX\");\n    }\n\n    #[test]\n    fn analyze_single_relu() {\n        let model = make_model_with_nodes(vec![make_node(\"Relu\", 0, vec![\"x\"], vec![\"y\"])]);\n        let result = analyze(&model, None).unwrap();\n        assert_eq!(result.node_count, 1);\n        let node = result.nodes.values().next().unwrap();\n        assert_eq!(node.node_type, \"Relu\");\n        assert!(node.parameter_details.is_empty());\n    
}\n\n    #[test]\n    fn analyze_conv_with_initializer() {\n        let weight_data: Vec<f32> = vec![1.0; 27];\n        let weight_tensor = onnx_proto::make_tensor(\n            \"conv_weight\",\n            TensorProto::FLOAT,\n            &[1, 3, 3, 3],\n            weight_data,\n        );\n        let conv = make_node(\"Conv\", 0, vec![\"x\", \"conv_weight\"], vec![\"y\"]);\n        let model = make_model_with_initializers(vec![conv], vec![weight_tensor]);\n        let result = analyze(&model, None).unwrap();\n        assert_eq!(result.initializer_count, 1);\n        let node = result.nodes.values().next().unwrap();\n        assert!(!node.parameter_details.is_empty());\n        let detail = node.parameter_details.get(\"conv_weight\").unwrap();\n        assert_eq!(detail.shape, vec![1, 3, 3, 3]);\n        assert_eq!(detail.size, 27);\n    }\n\n    #[test]\n    fn analyze_non_param_op_has_no_details() {\n        let weight_data: Vec<f32> = vec![1.0; 27];\n        let weight_tensor =\n            onnx_proto::make_tensor(\"add_weight\", TensorProto::FLOAT, &[1, 3, 3, 3], weight_data);\n        let add = make_node(\"Add\", 0, vec![\"x\", \"add_weight\"], vec![\"y\"]);\n        let model = make_model_with_initializers(vec![add], vec![weight_tensor]);\n        let result = analyze(&model, None).unwrap();\n        let node = result.nodes.values().next().unwrap();\n        assert!(node.parameter_details.is_empty());\n    }\n\n    #[test]\n    fn analyze_model_no_graph() {\n        let model = ModelProto {\n            graph: None,\n            ..Default::default()\n        };\n        assert!(analyze(&model, None).is_err());\n    }\n\n    #[test]\n    fn analyze_dependencies_tracked() {\n        let conv = make_node(\"Conv\", 0, vec![\"x\", \"w\"], vec![\"conv_out\"]);\n        let relu = make_node(\"Relu\", 1, vec![\"conv_out\"], vec![\"y\"]);\n        let model = make_model_with_nodes(vec![conv, relu]);\n        let result = analyze(&model, None).unwrap();\n        
assert_eq!(result.node_count, 2);\n\n        let relu_node = result\n            .nodes\n            .values()\n            .find(|n| n.node_type == \"Relu\")\n            .unwrap();\n        assert_eq!(relu_node.dependencies.input, vec![\"conv_out\"]);\n        assert_eq!(relu_node.dependencies.output, vec![\"y\"]);\n    }\n\n    #[test]\n    fn analyze_unnamed_nodes_get_generated_keys() {\n        let mut node = make_node(\"Relu\", 0, vec![\"x\"], vec![\"y\"]);\n        node.name = String::new();\n        let model = make_model_with_nodes(vec![node]);\n        let result = analyze(&model, None).unwrap();\n        assert!(result.nodes.contains_key(\"Relu_0\"));\n    }\n\n    #[test]\n    fn get_segment_dependencies_basic() {\n        let mut nodes = HashMap::new();\n        nodes.insert(\n            \"conv\".into(),\n            NodeAnalysis {\n                index: 0,\n                slice_name: \"Conv_0\".into(),\n                node_type: \"Conv\".into(),\n                parameter_details: HashMap::new(),\n                dependencies: NodeDependencies {\n                    input: vec![\"x\".into(), \"w\".into()],\n                    output: vec![\"conv_out\".into()],\n                },\n            },\n        );\n        nodes.insert(\n            \"relu\".into(),\n            NodeAnalysis {\n                index: 1,\n                slice_name: \"Relu_1\".into(),\n                node_type: \"Relu\".into(),\n                parameter_details: HashMap::new(),\n                dependencies: NodeDependencies {\n                    input: vec![\"conv_out\".into()],\n                    output: vec![\"relu_out\".into()],\n                },\n            },\n        );\n        let analysis = AnalysisResult {\n            original_model: None,\n            model_type: \"ONNX\".into(),\n            node_count: 2,\n            initializer_count: 1,\n            input_shape: vec![],\n            output_shapes: vec![],\n            output_names: vec![],\n     
       opset_version: Some(13),\n            nodes,\n            initializer_names: HashSet::from([\"w\".into()]),\n        };\n        let deps = get_segment_dependencies(&analysis, 0, 2);\n        assert!(deps.output.contains(&\"relu_out\".to_string()));\n        assert!(!deps.filtered_inputs.contains(&\"w\".to_string()));\n    }\n\n    fn make_attribute_graph(\n        name: &str,\n        graph: onnx_proto::GraphProto,\n    ) -> onnx_proto::AttributeProto {\n        onnx_proto::AttributeProto {\n            name: name.to_string(),\n            r#type: onnx_proto::onnx::attribute_proto::AttributeType::Graph as i32,\n            g: Some(graph),\n            ..Default::default()\n        }\n    }\n\n    #[test]\n    fn analyze_loop_captures_outer_scope_refs() {\n        let relu = make_node(\"Relu\", 0, vec![\"x\"], vec![\"relu_out\"]);\n\n        let body_node = onnx_proto::NodeProto {\n            op_type: \"Add\".into(),\n            name: \"body_add\".into(),\n            input: vec![\"body_in\".into(), \"relu_out\".into()],\n            output: vec![\"body_out\".into()],\n            attribute: vec![],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n        let body_input =\n            onnx_proto::make_tensor_value_info(\"body_in\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let body_cond_in = onnx_proto::make_tensor_value_info(\"cond_in\", TensorProto::BOOL, &[]);\n        let body_cond_out = onnx_proto::make_tensor_value_info(\"cond_out\", TensorProto::BOOL, &[]);\n        let body_output =\n            onnx_proto::make_tensor_value_info(\"body_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let body_graph = onnx_proto::make_graph(\n            \"loop_body\",\n            vec![body_node],\n            vec![body_cond_in.clone(), body_input],\n            vec![body_cond_out, body_output],\n      
      vec![],\n        );\n\n        let loop_node = onnx_proto::NodeProto {\n            op_type: \"Loop\".into(),\n            name: \"Loop_1\".into(),\n            input: vec![\"trip_count\".into(), \"cond\".into(), \"init_val\".into()],\n            output: vec![\"loop_out\".into()],\n            attribute: vec![make_attribute_graph(\"body\", body_graph)],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n\n        let input = onnx_proto::make_tensor_value_info(\"x\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let output =\n            onnx_proto::make_tensor_value_info(\"loop_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let trip_vi = onnx_proto::make_tensor_value_info(\"trip_count\", TensorProto::INT64, &[]);\n        let cond_vi = onnx_proto::make_tensor_value_info(\"cond\", TensorProto::BOOL, &[]);\n        let init_vi =\n            onnx_proto::make_tensor_value_info(\"init_val\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let graph = onnx_proto::make_graph(\n            \"test\",\n            vec![relu, loop_node],\n            vec![input, trip_vi, cond_vi, init_vi],\n            vec![output],\n            vec![],\n        );\n        let model = onnx_proto::make_model(graph, 13);\n\n        let result = analyze(&model, None).unwrap();\n        let loop_analysis = result\n            .nodes\n            .values()\n            .find(|n| n.node_type == \"Loop\")\n            .unwrap();\n\n        let loop_inputs = &loop_analysis.dependencies.input;\n        assert!(\n            loop_inputs.contains(&\"relu_out\".to_string()),\n            \"Loop node must include outer-scope ref 'relu_out' in its dependencies, got: {:?}\",\n            loop_inputs\n        );\n        for local in &[\"body_in\", \"body_out\", \"cond_in\", \"cond_out\"] {\n            assert!(\n                
!loop_inputs.contains(&local.to_string()),\n                \"body-local name '{}' must not leak into Loop dependencies, got: {:?}\",\n                local,\n                loop_inputs\n            );\n        }\n    }\n\n    #[test]\n    fn analyze_if_captures_outer_scope_refs() {\n        let relu = make_node(\"Relu\", 0, vec![\"x\"], vec![\"relu_out\"]);\n\n        let then_node = onnx_proto::NodeProto {\n            op_type: \"Identity\".into(),\n            name: \"then_id\".into(),\n            input: vec![\"relu_out\".into()],\n            output: vec![\"then_out\".into()],\n            attribute: vec![],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n        let then_output =\n            onnx_proto::make_tensor_value_info(\"then_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let then_graph = onnx_proto::make_graph(\n            \"then_branch\",\n            vec![then_node],\n            vec![],\n            vec![then_output],\n            vec![],\n        );\n\n        let else_node = onnx_proto::NodeProto {\n            op_type: \"Neg\".into(),\n            name: \"else_neg\".into(),\n            input: vec![\"relu_out\".into()],\n            output: vec![\"else_out\".into()],\n            attribute: vec![],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n        let else_output =\n            onnx_proto::make_tensor_value_info(\"else_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let else_graph = onnx_proto::make_graph(\n            \"else_branch\",\n            vec![else_node],\n            vec![],\n            vec![else_output],\n            vec![],\n        );\n\n        let if_node = onnx_proto::NodeProto {\n            
op_type: \"If\".into(),\n            name: \"If_1\".into(),\n            input: vec![\"cond\".into()],\n            output: vec![\"if_out\".into()],\n            attribute: vec![\n                make_attribute_graph(\"then_branch\", then_graph),\n                make_attribute_graph(\"else_branch\", else_graph),\n            ],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n\n        let input = onnx_proto::make_tensor_value_info(\"x\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let cond_vi = onnx_proto::make_tensor_value_info(\"cond\", TensorProto::BOOL, &[]);\n        let output =\n            onnx_proto::make_tensor_value_info(\"if_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let graph = onnx_proto::make_graph(\n            \"test\",\n            vec![relu, if_node],\n            vec![input, cond_vi],\n            vec![output],\n            vec![],\n        );\n        let model = onnx_proto::make_model(graph, 13);\n\n        let result = analyze(&model, None).unwrap();\n        let if_analysis = result.nodes.values().find(|n| n.node_type == \"If\").unwrap();\n\n        let if_inputs = &if_analysis.dependencies.input;\n        assert!(\n            if_inputs.contains(&\"relu_out\".to_string()),\n            \"If node must include outer-scope ref 'relu_out' from both branches, got: {:?}\",\n            if_inputs\n        );\n        for local in &[\"then_out\", \"else_out\"] {\n            assert!(\n                !if_inputs.contains(&local.to_string()),\n                \"branch-local name '{}' must not leak into If dependencies, got: {:?}\",\n                local,\n                if_inputs\n            );\n        }\n    }\n\n    #[test]\n    fn segment_deps_include_subgraph_outer_refs() {\n        let relu = make_node(\"Relu\", 0, vec![\"x\"], vec![\"relu_out\"]);\n\n        let body_node 
= onnx_proto::NodeProto {\n            op_type: \"Add\".into(),\n            name: \"body_add\".into(),\n            input: vec![\"body_in\".into(), \"relu_out\".into()],\n            output: vec![\"body_out\".into()],\n            attribute: vec![],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n        let body_input =\n            onnx_proto::make_tensor_value_info(\"body_in\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let body_cond_in = onnx_proto::make_tensor_value_info(\"cond_in\", TensorProto::BOOL, &[]);\n        let body_cond_out = onnx_proto::make_tensor_value_info(\"cond_out\", TensorProto::BOOL, &[]);\n        let body_output =\n            onnx_proto::make_tensor_value_info(\"body_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let body_graph = onnx_proto::make_graph(\n            \"loop_body\",\n            vec![body_node],\n            vec![body_cond_in, body_input],\n            vec![body_cond_out, body_output],\n            vec![],\n        );\n\n        let loop_node = onnx_proto::NodeProto {\n            op_type: \"Loop\".into(),\n            name: \"Loop_1\".into(),\n            input: vec![\"trip_count\".into(), \"cond\".into(), \"init_val\".into()],\n            output: vec![\"loop_out\".into()],\n            attribute: vec![make_attribute_graph(\"body\", body_graph)],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n\n        let input = onnx_proto::make_tensor_value_info(\"x\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let output =\n            onnx_proto::make_tensor_value_info(\"loop_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let trip_vi = onnx_proto::make_tensor_value_info(\"trip_count\", 
TensorProto::INT64, &[]);\n        let cond_vi = onnx_proto::make_tensor_value_info(\"cond\", TensorProto::BOOL, &[]);\n        let init_vi =\n            onnx_proto::make_tensor_value_info(\"init_val\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let graph = onnx_proto::make_graph(\n            \"test\",\n            vec![relu, loop_node],\n            vec![input, trip_vi, cond_vi, init_vi],\n            vec![output],\n            vec![],\n        );\n        let model = onnx_proto::make_model(graph, 13);\n        let result = analyze(&model, None).unwrap();\n\n        let deps = get_segment_dependencies(&result, 1, 2);\n        assert!(\n            deps.input.contains(&\"relu_out\".to_string()),\n            \"segment containing only Loop must list 'relu_out' as input dep, got: {:?}\",\n            deps.input\n        );\n        for local in &[\"body_in\", \"body_out\", \"cond_in\", \"cond_out\"] {\n            assert!(\n                !deps.input.contains(&local.to_string()),\n                \"body-local name '{}' must not appear in segment inputs, got: {:?}\",\n                local,\n                deps.input\n            );\n        }\n    }\n\n    #[test]\n    fn analyze_nested_subgraph_captures_outer_scope_refs() {\n        let relu = make_node(\"Relu\", 0, vec![\"x\"], vec![\"relu_out\"]);\n\n        let inner_add = onnx_proto::NodeProto {\n            op_type: \"Add\".into(),\n            name: \"inner_add\".into(),\n            input: vec![\"inner_in\".into(), \"relu_out\".into()],\n            output: vec![\"inner_out\".into()],\n            attribute: vec![],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n        let inner_input =\n            onnx_proto::make_tensor_value_info(\"inner_in\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let inner_output =\n            
onnx_proto::make_tensor_value_info(\"inner_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let inner_graph = onnx_proto::make_graph(\n            \"inner_then\",\n            vec![inner_add],\n            vec![inner_input],\n            vec![inner_output],\n            vec![],\n        );\n\n        let if_node_in_body = onnx_proto::NodeProto {\n            op_type: \"If\".into(),\n            name: \"nested_if\".into(),\n            input: vec![\"body_cond\".into()],\n            output: vec![\"body_out\".into()],\n            attribute: vec![\n                make_attribute_graph(\"then_branch\", inner_graph.clone()),\n                make_attribute_graph(\"else_branch\", inner_graph),\n            ],\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n        let body_cond_in = onnx_proto::make_tensor_value_info(\"cond_in\", TensorProto::BOOL, &[]);\n        let body_cond = onnx_proto::make_tensor_value_info(\"body_cond\", TensorProto::BOOL, &[]);\n        let body_cond_out = onnx_proto::make_tensor_value_info(\"cond_out\", TensorProto::BOOL, &[]);\n        let body_output =\n            onnx_proto::make_tensor_value_info(\"body_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let body_graph = onnx_proto::make_graph(\n            \"loop_body\",\n            vec![if_node_in_body],\n            vec![body_cond_in, body_cond],\n            vec![body_cond_out, body_output],\n            vec![],\n        );\n\n        let loop_node = onnx_proto::NodeProto {\n            op_type: \"Loop\".into(),\n            name: \"Loop_1\".into(),\n            input: vec![\"trip_count\".into(), \"cond\".into(), \"init_val\".into()],\n            output: vec![\"loop_out\".into()],\n            attribute: vec![make_attribute_graph(\"body\", body_graph)],\n            domain: String::new(),\n            doc_string: 
String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        };\n\n        let input = onnx_proto::make_tensor_value_info(\"x\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let output =\n            onnx_proto::make_tensor_value_info(\"loop_out\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let trip_vi = onnx_proto::make_tensor_value_info(\"trip_count\", TensorProto::INT64, &[]);\n        let cond_vi = onnx_proto::make_tensor_value_info(\"cond\", TensorProto::BOOL, &[]);\n        let init_vi =\n            onnx_proto::make_tensor_value_info(\"init_val\", TensorProto::FLOAT, &[1, 3, 8, 8]);\n        let graph = onnx_proto::make_graph(\n            \"test\",\n            vec![relu, loop_node],\n            vec![input, trip_vi, cond_vi, init_vi],\n            vec![output],\n            vec![],\n        );\n        let model = onnx_proto::make_model(graph, 13);\n\n        let result = analyze(&model, None).unwrap();\n        let loop_analysis = result\n            .nodes\n            .values()\n            .find(|n| n.node_type == \"Loop\")\n            .unwrap();\n\n        let nested_inputs = &loop_analysis.dependencies.input;\n        assert!(\n            nested_inputs.contains(&\"relu_out\".to_string()),\n            \"Loop with nested If subgraph referencing outer-scope 'relu_out' must capture it, got: {:?}\",\n            nested_inputs\n        );\n        for local in &[\"body_cond\", \"inner_in\", \"inner_out\", \"body_out\"] {\n            assert!(\n                !nested_inputs.contains(&local.to_string()),\n                \"nested-body-local name '{}' must not leak into Loop dependencies, got: {:?}\",\n                local,\n                nested_inputs\n            );\n        }\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/autotiler.rs",
    "content": "use std::collections::{HashMap, HashSet};\nuse std::path::Path;\n\nuse super::onnx_proto::{self, GraphProto, ModelProto, NodeProto, TensorProto};\nuse crate::error::Result;\nuse crate::schema::tiling::{ChannelGroupInfo, ChannelSplitInfo, DimSplitKind};\n\nfn try_pair(v: &[i64]) -> Option<[i64; 2]> {\n    if v.len() == 2 {\n        Some([v[0], v[1]])\n    } else {\n        None\n    }\n}\n\nfn try_quad(v: &[i64]) -> Option<[i64; 4]> {\n    if v.len() == 4 {\n        Some([v[0], v[1], v[2], v[3]])\n    } else {\n        None\n    }\n}\n\npub(crate) fn model_opset(model: &ModelProto) -> i64 {\n    model\n        .opset_import\n        .iter()\n        .filter(|o| o.domain.is_empty())\n        .map(|o| o.version)\n        .max()\n        .unwrap_or(13)\n}\n\nfn is_elementwise(op: &str) -> bool {\n    super::is_elementwise(op)\n}\n\n#[derive(Debug, Clone)]\npub struct ChannelSplitParams {\n    pub c_in: i64,\n    pub c_out: i64,\n    pub num_groups: i64,\n    pub channels_per_group: i64,\n    pub h: i64,\n    pub w: i64,\n    pub slice_idx: usize,\n}\n\nstruct PoolParams {\n    node_idx: usize,\n    kernel: [i64; 2],\n    stride: [i64; 2],\n    dilation: [i64; 2],\n    pads: [i64; 4],\n}\n\nimpl PoolParams {\n    fn from_node(node: &NodeProto, node_idx: usize) -> Option<PoolParams> {\n        if node.op_type != \"MaxPool\" {\n            return None;\n        }\n        let kernel = try_pair(&onnx_proto::get_attribute_ints(node, \"kernel_shape\")?)?;\n        let stride = match onnx_proto::get_attribute_ints(node, \"strides\") {\n            None => [1, 1],\n            Some(v) => try_pair(&v)?,\n        };\n        let dilation = match onnx_proto::get_attribute_ints(node, \"dilations\") {\n            None => [1, 1],\n            Some(v) => try_pair(&v)?,\n        };\n        let auto_pad = node\n            .attribute\n            .iter()\n            .find(|a| a.name == \"auto_pad\")\n            .map(|a| a.s.as_slice());\n        if 
matches!(auto_pad, Some(v) if !v.is_empty() && v != b\"NOTSET\") {\n            return None;\n        }\n        let pads = match onnx_proto::get_attribute_ints(node, \"pads\") {\n            None => [0, 0, 0, 0],\n            Some(v) => try_quad(&v)?,\n        };\n        let ceil_mode = onnx_proto::get_attribute_int(node, \"ceil_mode\").unwrap_or(0);\n        if ceil_mode != 0 {\n            return None;\n        }\n        if kernel.iter().any(|&v| v <= 0) || stride.iter().any(|&v| v <= 0) {\n            return None;\n        }\n        if dilation.iter().any(|&v| v <= 0) || pads.iter().any(|&v| v < 0) {\n            return None;\n        }\n        Some(PoolParams {\n            node_idx,\n            kernel,\n            stride,\n            dilation,\n            pads,\n        })\n    }\n}\n\nfn get_pool_params(graph: &GraphProto) -> Option<PoolParams> {\n    for (idx, node) in graph.node.iter().enumerate() {\n        if let Some(pp) = PoolParams::from_node(node, idx) {\n            return Some(pp);\n        }\n    }\n    None\n}\n\nstruct ConvParams {\n    node_idx: usize,\n    kernel: [i64; 2],\n    stride: [i64; 2],\n    dilation: [i64; 2],\n    pads: [i64; 4],\n    group: i64,\n    c_out: i64,\n    c_in: i64,\n}\n\nimpl ConvParams {\n    fn from_node(node: &NodeProto, node_idx: usize, graph: &GraphProto) -> Option<ConvParams> {\n        if node.op_type != \"Conv\" {\n            return None;\n        }\n        let w_name = node.input.get(1)?;\n        let w = graph.initializer.iter().find(|t| &t.name == w_name)?;\n        if w.dims.len() != 4 {\n            return None;\n        }\n        let c_out = w.dims[0];\n        let c_in = w.dims[1];\n        if c_out <= 0 || c_in <= 0 {\n            return None;\n        }\n\n        let inferred_kernel = [w.dims[2], w.dims[3]];\n        let kernel = match onnx_proto::get_attribute_ints(node, \"kernel_shape\") {\n            Some(v) => {\n                let k = try_pair(&v)?;\n                if k != 
inferred_kernel {\n                    return None;\n                }\n                k\n            }\n            None => inferred_kernel,\n        };\n        let stride = match onnx_proto::get_attribute_ints(node, \"strides\") {\n            None => [1, 1],\n            Some(v) => try_pair(&v)?,\n        };\n        let dilation = match onnx_proto::get_attribute_ints(node, \"dilations\") {\n            None => [1, 1],\n            Some(v) => try_pair(&v)?,\n        };\n        let auto_pad = node\n            .attribute\n            .iter()\n            .find(|a| a.name == \"auto_pad\")\n            .map(|a| a.s.as_slice());\n        if matches!(auto_pad, Some(v) if !v.is_empty() && v != b\"NOTSET\") {\n            return None;\n        }\n        let pads = match onnx_proto::get_attribute_ints(node, \"pads\") {\n            None => [0, 0, 0, 0],\n            Some(v) => try_quad(&v)?,\n        };\n        if kernel.iter().any(|&v| v <= 0) {\n            return None;\n        }\n        if stride.iter().any(|&v| v <= 0) {\n            return None;\n        }\n        if dilation.iter().any(|&v| v <= 0) {\n            return None;\n        }\n        if pads.iter().any(|&v| v < 0) {\n            return None;\n        }\n        let group = onnx_proto::get_attribute_int(node, \"group\").unwrap_or(1);\n        if group <= 0 {\n            return None;\n        }\n\n        Some(ConvParams {\n            node_idx,\n            kernel,\n            stride,\n            dilation,\n            pads,\n            group,\n            c_out,\n            c_in,\n        })\n    }\n}\n\nfn get_conv_params(graph: &GraphProto) -> Option<ConvParams> {\n    for (idx, node) in graph.node.iter().enumerate() {\n        if let Some(cp) = ConvParams::from_node(node, idx, graph) {\n            return Some(cp);\n        }\n    }\n    None\n}\n\nfn effective_kernel(kernel: [i64; 2], dilation: [i64; 2]) -> Option<[i64; 2]> {\n    let ek0 = kernel[0]\n        .checked_sub(1)?\n        
.checked_mul(dilation[0])?\n        .checked_add(1)?;\n    let ek1 = kernel[1]\n        .checked_sub(1)?\n        .checked_mul(dilation[1])?\n        .checked_add(1)?;\n    Some([ek0, ek1])\n}\n\nfn conv_output_hw(\n    h_in: i64,\n    w_in: i64,\n    pads: [i64; 4],\n    kernel: [i64; 2],\n    dilation: [i64; 2],\n    stride: [i64; 2],\n) -> Option<(i64, i64)> {\n    if stride[0] <= 0 || stride[1] <= 0 {\n        return None;\n    }\n    let eff = effective_kernel(kernel, dilation)?;\n    let num_h = h_in\n        .checked_add(pads[0])?\n        .checked_add(pads[2])?\n        .checked_sub(eff[0])?;\n    let num_w = w_in\n        .checked_add(pads[1])?\n        .checked_add(pads[3])?\n        .checked_sub(eff[1])?;\n    let out_h = num_h.div_euclid(stride[0]).checked_add(1)?;\n    let out_w = num_w.div_euclid(stride[1]).checked_add(1)?;\n    if out_h <= 0 || out_w <= 0 {\n        return None;\n    }\n    Some((out_h, out_w))\n}\n\nfn compute_halo_size(pads: [i64; 4]) -> Option<[i64; 4]> {\n    if pads.iter().any(|&v| v < 0) {\n        return None;\n    }\n    Some(pads)\n}\n\nfn compute_min_spatial_tile(kernel: [i64; 2], dilation: [i64; 2]) -> Option<i64> {\n    let eff = effective_kernel(kernel, dilation)?;\n    eff[0].max(eff[1]).checked_add(1)\n}\n\nstruct SpatialKernelParams {\n    kernel: [i64; 2],\n    stride: [i64; 2],\n    dilation: [i64; 2],\n    pads: [i64; 4],\n}\n\nfn extract_spatial_kernel_params(\n    graph: &GraphProto,\n    primary_op: &str,\n) -> Option<SpatialKernelParams> {\n    if graph.input.len() > 1 {\n        return None;\n    }\n    let op_count = graph\n        .node\n        .iter()\n        .filter(|n| n.op_type == primary_op)\n        .count();\n    if op_count != 1 {\n        return None;\n    }\n    let (node_idx, kernel, stride, dilation, pads) = if primary_op == \"Conv\" {\n        let cp = get_conv_params(graph)?;\n        (cp.node_idx, cp.kernel, cp.stride, cp.dilation, cp.pads)\n    } else if primary_op == \"MaxPool\" {\n        
let pp = get_pool_params(graph)?;\n        (pp.node_idx, pp.kernel, pp.stride, pp.dilation, pp.pads)\n    } else {\n        return None;\n    };\n    if node_idx != 0 {\n        return None;\n    }\n    let ops: HashSet<&str> = graph.node.iter().map(|n| n.op_type.as_str()).collect();\n    if ops.iter().any(|&o| o != primary_op && !is_elementwise(o)) {\n        return None;\n    }\n    Some(SpatialKernelParams {\n        kernel,\n        stride,\n        dilation,\n        pads,\n    })\n}\n\nfn is_spatial_tileable(graph: &GraphProto, primary_op: &str) -> bool {\n    let Some(sp) = extract_spatial_kernel_params(graph, primary_op) else {\n        return false;\n    };\n    let Some(eff) = effective_kernel(sp.kernel, sp.dilation) else {\n        return false;\n    };\n    let total_pad_h = sp.pads[0] + sp.pads[2];\n    let total_pad_w = sp.pads[1] + sp.pads[3];\n    total_pad_h >= eff[0] - sp.stride[0] && total_pad_w >= eff[1] - sp.stride[1]\n}\n\nfn is_standard_conv_slice(graph: &GraphProto) -> Option<ConvParams> {\n    extract_spatial_kernel_params(graph, \"Conv\")?;\n    get_conv_params(graph)\n}\n\nfn is_tileable(graph: &GraphProto) -> bool {\n    is_spatial_tileable(graph, \"Conv\")\n}\n\nfn is_channel_splittable(graph: &GraphProto) -> bool {\n    let Some(cp) = is_standard_conv_slice(graph) else {\n        return false;\n    };\n    cp.group == 1\n}\n\nfn get_model_dimensions(graph: &GraphProto) -> Option<(String, String, i64, i64, i64)> {\n    let inp = graph.input.first()?;\n    let out = graph.output.first()?;\n    let dims = onnx_proto::vi_shape(inp);\n    if dims.len() != 4 || dims[1] <= 0 || dims[2] <= 0 || dims[3] <= 0 {\n        return None;\n    }\n    Some((\n        inp.name.clone(),\n        out.name.clone(),\n        dims[1],\n        dims[2],\n        dims[3],\n    ))\n}\n\nfn is_elementwise_only_slice(graph: &GraphProto) -> bool {\n    if graph.node.is_empty() || graph.input.is_empty() {\n        return false;\n    }\n    graph.node.iter().all(|n| 
is_elementwise(&n.op_type))\n}\n\nfn find_weights_and_bias(\n    graph: &GraphProto,\n    conv_node: &NodeProto,\n) -> (Option<WeightInfo>, Option<Vec<f32>>) {\n    let mut weights: Option<WeightInfo> = None;\n    let mut bias: Option<Vec<f32>> = None;\n\n    for init in &graph.initializer {\n        if conv_node.input.len() > 1 && init.name == conv_node.input[1] {\n            let data = onnx_proto::tensor_to_f32(init);\n            weights = Some(WeightInfo {\n                data,\n                dims: init.dims.clone(),\n            });\n        }\n        if conv_node.input.len() > 2 && init.name == conv_node.input[2] {\n            bias = Some(onnx_proto::tensor_to_f32(init));\n        }\n    }\n    (weights, bias)\n}\n\nstruct WeightInfo {\n    data: Vec<f32>,\n    dims: Vec<i64>,\n}\n\nstruct SlicePrologue<'a> {\n    graph: &'a GraphProto,\n    cp: ConvParams,\n    weights: Option<WeightInfo>,\n    bias: Option<Vec<f32>>,\n}\n\nfn extract_slice_prologue(model: &ModelProto) -> Option<SlicePrologue<'_>> {\n    let graph = model.graph.as_ref()?;\n    let cp = get_conv_params(graph)?;\n    let conv_node = &graph.node[cp.node_idx];\n    let (weights, bias) = find_weights_and_bias(graph, conv_node);\n    if let Some(ref w) = weights {\n        if w.dims.len() != 4 {\n            return None;\n        }\n        let c_out = usize::try_from(w.dims[0]).ok()?;\n        let c_in = usize::try_from(w.dims[1]).ok()?;\n        let kh = usize::try_from(w.dims[2]).ok()?;\n        let kw = usize::try_from(w.dims[3]).ok()?;\n        let expected = c_out.checked_mul(c_in)?.checked_mul(kh)?.checked_mul(kw)?;\n        if w.data.len() != expected {\n            return None;\n        }\n        if let Some(ref b) = bias\n            && b.len() != c_out\n        {\n            return None;\n        }\n    }\n    Some(SlicePrologue {\n        graph,\n        cp,\n        weights,\n        bias,\n    })\n}\n\nfn find_optimal_tile_size(\n    spatial_dim: i64,\n    target: i64,\n    
min_tile: i64,\n    stride: i64,\n) -> Option<i64> {\n    if min_tile <= target && target < spatial_dim {\n        for tile in (min_tile..=target).rev() {\n            if spatial_dim % tile == 0 && tile % stride == 0 {\n                return Some(tile);\n            }\n        }\n    }\n    None\n}\n\nfn calculate_spatial_tile_config(\n    channels: i64,\n    h: i64,\n    w: i64,\n    tile_size: i64,\n    min_tile: i64,\n    stride: i64,\n) -> (Option<i64>, Option<&'static str>) {\n    let total = channels * h * w;\n    if total <= tile_size {\n        return (None, Some(\"already_fits\"));\n    }\n    let max_tile = ((tile_size as f64) / (channels as f64)).sqrt() as i64;\n    if max_tile < min_tile {\n        return (None, Some(\"min_tile_too_large\"));\n    }\n    let target_tile = max_tile.min(h).min(w);\n    match find_optimal_tile_size(h, target_tile, min_tile, stride) {\n        Some(t) => (Some(t), None),\n        None => (None, Some(\"no_divisor\")),\n    }\n}\n\nfn calculate_channel_split_config(\n    c_in: i64,\n    _c_out: i64,\n    h: i64,\n    w: i64,\n    tile_size: i64,\n) -> Option<(i64, i64)> {\n    if h == 0 || w == 0 {\n        return None;\n    }\n    let max_ch = tile_size / (h * w);\n    if max_ch >= 1 && max_ch < c_in {\n        let mut num_groups = (c_in + max_ch - 1) / max_ch;\n        if num_groups > 1 {\n            let mut cpg = (c_in + num_groups - 1) / num_groups;\n            while cpg * (num_groups - 1) >= c_in && num_groups > 1 {\n                num_groups -= 1;\n                cpg = (c_in + num_groups - 1) / num_groups;\n            }\n            if num_groups > 1 {\n                return Some((num_groups, cpg));\n            }\n        }\n    }\n    None\n}\n\npub const CONV_TILE_BUDGET: i64 = 512;\npub const POOL_TILE_BUDGET: i64 = 1024;\n\npub fn detect_tiling_needs(\n    model: &ModelProto,\n    tile_size: Option<usize>,\n) -> Option<TilingDetection> {\n    let graph = model.graph.as_ref()?;\n    tile_size?;\n\n    let 
dims_4d = get_model_dimensions(graph);\n\n    if let Some((ref inp_name, ref out_name, c_in, h, w)) = dims_4d\n        && let Some(cp) = get_conv_params(graph)\n    {\n        let budget = CONV_TILE_BUDGET;\n        let c_out = cp.c_out;\n\n        if is_tileable(graph) {\n            let min_tile = compute_min_spatial_tile(cp.kernel, cp.dilation)?;\n            let (actual_tile, _skip_reason) =\n                calculate_spatial_tile_config(c_in, h, w, budget, min_tile, cp.stride[0]);\n\n            if let Some(actual_tile) = actual_tile\n                && h % actual_tile == 0\n                && w % actual_tile == 0\n                && actual_tile % cp.stride[0] == 0\n                && actual_tile % cp.stride[1] == 0\n            {\n                let tiles_y = h / actual_tile;\n                let tiles_x = w / actual_tile;\n                if tiles_y * tiles_x >= 2 {\n                    let halo = compute_halo_size(cp.pads)?;\n                    return Some(TilingDetection::Spatial {\n                        input_name: inp_name.clone(),\n                        output_name: out_name.clone(),\n                        input_names: vec![inp_name.clone()],\n                        ndim: 4,\n                        c_in,\n                        c_out,\n                        h,\n                        w,\n                        tile_size: actual_tile,\n                        halo,\n                        tiles_y,\n                        tiles_x,\n                        out_tile: [actual_tile / cp.stride[0], actual_tile / cp.stride[1]],\n                        stride: cp.stride,\n                    });\n                }\n            }\n        }\n\n        if is_channel_splittable(graph)\n            && let Some((num_groups, cpg)) =\n                calculate_channel_split_config(c_in, c_out, h, w, budget)\n        {\n            return Some(TilingDetection::ChannelSplit {\n                input_name: inp_name.clone(),\n                output_name: 
out_name.clone(),\n                c_in,\n                c_out,\n                h,\n                w,\n                num_groups,\n                channels_per_group: cpg,\n            });\n        }\n    }\n\n    if let Some((ref inp_name, ref out_name, c_in, h, w)) = dims_4d\n        && is_spatial_tileable(graph, \"MaxPool\")\n        && let Some(pp) = get_pool_params(graph)\n    {\n        let budget = POOL_TILE_BUDGET;\n        let min_tile = compute_min_spatial_tile(pp.kernel, pp.dilation)?;\n        let (actual_tile, _skip_reason) =\n            calculate_spatial_tile_config(c_in, h, w, budget, min_tile, pp.stride[0]);\n\n        if let Some(actual_tile) = actual_tile\n            && h % actual_tile == 0\n            && w % actual_tile == 0\n            && actual_tile % pp.stride[0] == 0\n            && actual_tile % pp.stride[1] == 0\n        {\n            let tiles_y = h / actual_tile;\n            let tiles_x = w / actual_tile;\n            if tiles_y * tiles_x >= 2 {\n                let halo = compute_halo_size(pp.pads)?;\n                return Some(TilingDetection::Spatial {\n                    input_name: inp_name.clone(),\n                    output_name: out_name.clone(),\n                    input_names: vec![inp_name.clone()],\n                    ndim: 4,\n                    c_in,\n                    c_out: c_in,\n                    h,\n                    w,\n                    tile_size: actual_tile,\n                    halo,\n                    tiles_y,\n                    tiles_x,\n                    out_tile: [actual_tile / pp.stride[0], actual_tile / pp.stride[1]],\n                    stride: pp.stride,\n                });\n            }\n        }\n    }\n\n    if let Some(detection) = detect_elementwise_fixed_segments(graph) {\n        return Some(detection);\n    }\n\n    None\n}\n\npub const ELEMENTWISE_SEGMENT_SIZE: i64 = 1024;\n\nfn elementwise_segment_size() -> i64 {\n    std::env::var(\"DSPERSE_EW_SEGMENT_SIZE\")\n   
     .ok()\n        .and_then(|v| v.parse::<i64>().ok())\n        .filter(|&v| v > 0)\n        .unwrap_or(ELEMENTWISE_SEGMENT_SIZE)\n}\n\nfn detect_elementwise_fixed_segments(graph: &GraphProto) -> Option<TilingDetection> {\n    if !is_elementwise_only_slice(graph) {\n        return None;\n    }\n    let seg_size = elementwise_segment_size();\n    let out = graph.output.first()?;\n    let first_inp = graph.input.first()?;\n    let first_dims = onnx_proto::vi_shape(first_inp);\n    if first_dims.is_empty() || first_dims.iter().any(|&d| d <= 0) {\n        return None;\n    }\n    let total_elements = first_dims\n        .iter()\n        .try_fold(1i64, |acc, &d| acc.checked_mul(d))?;\n    if total_elements <= seg_size {\n        return None;\n    }\n    let last_dim = *first_dims.last().unwrap_or(&0);\n    let mut effective_seg_size = seg_size;\n    for init in &graph.initializer {\n        let vol: i64 = init.dims.iter().product();\n        if vol <= 1 || vol == seg_size {\n            continue;\n        }\n        if init.dims.len() == 1 && init.dims[0] == last_dim && last_dim > 0 {\n            effective_seg_size = last_dim;\n            continue;\n        }\n        return None;\n    }\n    let seg_size = effective_seg_size;\n    let mut input_names = Vec::with_capacity(graph.input.len());\n    for inp in &graph.input {\n        let d = onnx_proto::vi_shape(inp);\n        if d != first_dims || d.iter().any(|&v| v <= 0) {\n            return None;\n        }\n        input_names.push(inp.name.clone());\n    }\n    #[allow(clippy::manual_div_ceil)]\n    let num_segments = (total_elements + seg_size - 1) / seg_size;\n    if num_segments < 2 {\n        return None;\n    }\n    let primary_name = input_names[0].clone();\n    Some(TilingDetection::FixedSegment {\n        input_name: primary_name,\n        output_name: out.name.clone(),\n        input_names,\n        total_elements,\n        segment_size: seg_size,\n        num_segments,\n        original_shape: 
first_dims,\n    })\n}\n\npub const MAX_ESTIMATED_CONSTRAINTS: u64 = 750_000;\n\n/// Return the smallest divisor of `dim` that is >= `target`.  Returns\n/// `None` if no such divisor exists in `(0, dim]`, which is the\n/// signal to refuse the dim-split: pad-then-trim on the last group\n/// would inject zeros into reductions on non-split axes (Softmax,\n/// LayerNorm, ReduceMean, etc.) and contaminate the unpadded\n/// region's outputs.\nfn smallest_divisor_at_least(dim: usize, target: usize) -> Option<usize> {\n    if dim == 0 || target == 0 {\n        return None;\n    }\n    let target = target.min(dim);\n    (target..=dim).find(|&g| dim.is_multiple_of(g))\n}\n\n#[derive(Debug, Clone)]\npub struct DimSplitDetection {\n    pub split_kind: DimSplitKind,\n    pub split_dim: usize,\n    pub dim_size: usize,\n    pub num_groups: usize,\n    pub elements_per_group: usize,\n    pub input_name: String,\n    pub output_name: String,\n    pub concat_axis: usize,\n    pub estimated_constraints: u64,\n    pub weight_name: Option<String>,\n    pub k_dim: usize,\n    pub n_dim: usize,\n    pub k_chunks: usize,\n}\n\npub fn estimate_slice_constraints(nodes: &[NodeProto], shapes: &HashMap<String, Vec<i64>>) -> u64 {\n    let config = jstprove_circuits::api::EstimationConfig::bn254_defaults();\n    let mut total: u64 = 0;\n\n    let to_usize_shape = |name: &String| -> Vec<usize> {\n        shapes\n            .get(name)\n            .map(|s| s.iter().map(|&d| d.max(1) as usize).collect())\n            .unwrap_or_default()\n    };\n\n    for node in nodes {\n        let input_shapes: Vec<Vec<usize>> = node.input.iter().map(&to_usize_shape).collect();\n        let output_shapes: Vec<Vec<usize>> = node.output.iter().map(&to_usize_shape).collect();\n\n        let cost = jstprove_circuits::api::estimate_op_constraints(\n            &node.op_type,\n            &input_shapes,\n            &output_shapes,\n            &config,\n        );\n        total = total.saturating_add(cost);\n   
 }\n    total\n}\n\npub fn detect_dim_split(\n    nodes: &[NodeProto],\n    shapes: &HashMap<String, Vec<i64>>,\n    initializer_names: &HashSet<String>,\n    model_opset: i64,\n) -> Option<DimSplitDetection> {\n    let estimated = estimate_slice_constraints(nodes, shapes);\n    if estimated <= MAX_ESTIMATED_CONSTRAINTS {\n        return None;\n    }\n\n    let target_groups = estimated.div_ceil(MAX_ESTIMATED_CONSTRAINTS) as usize;\n\n    for (idx, node) in nodes.iter().enumerate() {\n        if matches!(node.op_type.as_str(), \"MatMul\" | \"Gemm\") {\n            // Gemm with a bias (input C) is not yet supported by the dim-split\n            // template builder; skip so the template construction downstream\n            // stays in sync with the detector.\n            if node.op_type == \"Gemm\" && node.input.get(2).is_some_and(|s: &String| !s.is_empty()) {\n                continue;\n            }\n            // The dim-split runner replaces the entire slice execution with\n            // the patched MatMul template and only writes ds.output_name to\n            // the tensor cache. If this MatMul/Gemm output is consumed by a\n            // later node in the same slice, those downstream ops would never\n            // execute and the slice would publish the wrong tensor. 
Decline\n            // and let the search continue or fall through to other paths.\n            let Some(node_out) = node.output.first().filter(|s| !s.is_empty()) else {\n                continue;\n            };\n            let consumed_downstream = nodes\n                .iter()\n                .skip(idx + 1)\n                .any(|later| later.input.iter().any(|i| i == node_out));\n            if consumed_downstream {\n                continue;\n            }\n            let Some(weight_name) = node.input.get(1) else {\n                continue;\n            };\n            if !initializer_names.contains(weight_name) {\n                continue;\n            }\n            let Some(weight_shape) = shapes.get(weight_name) else {\n                continue;\n            };\n            if weight_shape.len() != 2 {\n                continue;\n            }\n            // Gemm with transA=1 transposes the activation matrix, which the\n            // single-row sequence tile and the rank-2 template do not model.\n            // Skip so detection stays consistent with the template builder.\n            if node.op_type == \"Gemm\"\n                && super::onnx_proto::get_attribute_int(node, \"transA\").unwrap_or(0) == 1\n            {\n                continue;\n            }\n            let trans_b = node.op_type == \"Gemm\"\n                && super::onnx_proto::get_attribute_int(node, \"transB\").unwrap_or(0) == 1;\n            let k_dim = if trans_b {\n                weight_shape[1] as usize\n            } else {\n                weight_shape[0] as usize\n            };\n            let n_dim = if trans_b {\n                weight_shape[0] as usize\n            } else {\n                weight_shape[1] as usize\n            };\n            let Some(inp_shape) = node.input.first().and_then(|name| shapes.get(name)) else {\n                continue;\n            };\n            let total_rows: usize = inp_shape\n                .iter()\n                
.take(inp_shape.len().saturating_sub(1))\n                .map(|&d| d.max(1) as usize)\n                .product();\n            if total_rows == 0 || k_dim == 0 || n_dim == 0 {\n                continue;\n            }\n            let row_cost = k_dim.saturating_mul(n_dim).saturating_mul(2);\n            let max_per_chunk = MAX_ESTIMATED_CONSTRAINTS as usize;\n            // Even with k_chunks == k_dim (chunk_size == 1), the per-chunk\n            // cost is at minimum n_dim * 2. If that alone exceeds the budget\n            // the split is infeasible; let the caller fall through to other\n            // detection paths.\n            if n_dim.saturating_mul(2) > max_per_chunk {\n                continue;\n            }\n            let mut k_chunks = if row_cost > max_per_chunk {\n                row_cost.div_ceil(max_per_chunk).max(1)\n            } else {\n                1\n            };\n            k_chunks = k_chunks.min(k_dim);\n            while k_chunks < k_dim\n                && k_dim\n                    .div_ceil(k_chunks)\n                    .saturating_mul(n_dim)\n                    .saturating_mul(2)\n                    > max_per_chunk\n            {\n                k_chunks += 1;\n            }\n            if total_rows == 1 && k_chunks == 1 {\n                continue;\n            }\n            let Some(input_name) = node.input.first().filter(|s| !s.is_empty()).cloned() else {\n                continue;\n            };\n            let Some(output_name) = node.output.first().filter(|s| !s.is_empty()).cloned() else {\n                continue;\n            };\n            return Some(DimSplitDetection {\n                split_kind: DimSplitKind::MatMulOutputDim,\n                split_dim: 0,\n                dim_size: total_rows,\n                num_groups: total_rows,\n                elements_per_group: 1,\n                input_name,\n                output_name,\n                concat_axis: 0,\n                
estimated_constraints: estimated,\n                weight_name: Some(weight_name.clone()),\n                k_dim,\n                n_dim,\n                k_chunks,\n            });\n        }\n    }\n\n    // Slice-boundary inputs are ones not produced by any node inside\n    // this slice; everything else is internal data flow that the\n    // dim-split rewrite cannot honour as a true split axis.\n    let slice_internal_outputs: HashSet<&str> = nodes\n        .iter()\n        .flat_map(|n| n.output.iter())\n        .filter(|s| !s.is_empty())\n        .map(String::as_str)\n        .collect();\n\n    for node in nodes {\n        if node.op_type == \"Softmax\" {\n            let Some(softmax_in) = node.input.first().and_then(|name| shapes.get(name)) else {\n                continue;\n            };\n            if softmax_in.len() != 4 {\n                continue;\n            }\n            // ONNX Softmax default axis: opset >= 13 -> -1\n            // (last axis), opset < 13 -> 1 (channel axis).\n            // unwrap_or(-1) silently mismatches runtime semantics on\n            // opset <13 models that omit the attribute.\n            let default_axis: i64 = if model_opset >= 13 { -1 } else { 1 };\n            let softmax_axis = onnx_proto::get_attribute_int(node, \"axis\").unwrap_or(default_axis);\n            let softmax_axis_abs = if softmax_axis < 0 {\n                (softmax_in.len() as i64 + softmax_axis).max(0) as usize\n            } else {\n                softmax_axis as usize\n            };\n            // Find the attention-block input among the slice inputs:\n            // the first slice-boundary tensor (external -- not the\n            // output of any other node in this slice, and not an\n            // initializer) whose rank matches the softmax input rank\n            // (Q/V-like activation).\n            let attn_input = nodes.iter().flat_map(|n| n.input.iter()).find(|name| {\n                !name.is_empty()\n                    && 
!initializer_names.contains(name.as_str())\n                    && !slice_internal_outputs.contains(name.as_str())\n                    && shapes.get(*name).is_some_and(|s| s.len() == 4 && s[0] > 0)\n            });\n            let Some(attn_input_name) = attn_input.cloned() else {\n                continue;\n            };\n            let Some(attn_shape) = shapes.get(&attn_input_name) else {\n                continue;\n            };\n            // Choose the dim (among 0..rank) that is not the softmax-reduction\n            // axis and yields the highest axis size; that axis gives the\n            // most groups and the lowest per-group cost.\n            let mut best: Option<(usize, usize, DimSplitKind)> = None;\n            for (d, &axis_len) in attn_shape.iter().enumerate() {\n                if d == softmax_axis_abs {\n                    continue;\n                }\n                let dim_size = axis_len.max(1) as usize;\n                if dim_size < 2 {\n                    continue;\n                }\n                let kind = if d == 1 {\n                    DimSplitKind::HeadDim\n                } else {\n                    DimSplitKind::BatchDim\n                };\n                let better = best.as_ref().is_none_or(|(_, sz, _)| dim_size > *sz);\n                if better {\n                    best = Some((d, dim_size, kind));\n                }\n            }\n            let Some((split_dim, dim_size, split_kind)) = best else {\n                continue;\n            };\n            let num_groups = match smallest_divisor_at_least(dim_size, target_groups) {\n                Some(g) => g,\n                None => continue,\n            };\n            let elements_per_group = dim_size / num_groups;\n            let output_name = nodes\n                .last()\n                .and_then(|n| n.output.first())\n                .filter(|s| !s.is_empty())\n                .cloned()\n                .unwrap_or_else(|| 
node.output.first().cloned().unwrap_or_default());\n            if output_name.is_empty() {\n                continue;\n            }\n            // Reject the split when axis tracing through the slice cannot\n            // prove the split axis lands at the same position (and size)\n            // in the final output.  Shape-reordering ops (Reshape,\n            // Transpose, Flatten, Squeeze, Unsqueeze, Concat on the\n            // split axis) are non-trivial to follow here, so we require\n            // the output shape to match the attention input at split_dim.\n            let Some(out_shape) = shapes.get(&output_name) else {\n                continue;\n            };\n            if out_shape.len() != attn_shape.len() || out_shape[split_dim] != attn_shape[split_dim]\n            {\n                continue;\n            }\n            return Some(DimSplitDetection {\n                split_kind,\n                split_dim,\n                dim_size,\n                num_groups,\n                elements_per_group,\n                input_name: attn_input_name,\n                output_name,\n                concat_axis: split_dim,\n                estimated_constraints: estimated,\n                weight_name: None,\n                k_dim: 0,\n                n_dim: 0,\n                k_chunks: 1,\n            });\n        }\n    }\n\n    let first_non_init_input = nodes.first().and_then(|n| {\n        n.input\n            .iter()\n            .find(|name| !name.is_empty() && !initializer_names.contains(name.as_str()))\n    });\n    let first_input_shape = first_non_init_input.and_then(|name| shapes.get(name))?;\n    if first_input_shape.is_empty() {\n        return None;\n    }\n\n    // Conv / ConvTranspose / Pooling are not separable along arbitrary\n    // input axes: splitting the input channel or the spatial dimensions\n    // produces semantically incorrect per-group outputs. 
The dedicated\n    // detection paths (conv spatial tiling, channel splitting) handle\n    // these ops correctly; this generic fallback refuses to emit a\n    // split for them.  MatMul / Gemm are *not* listed here: their\n    // dedicated dim-split-k path handles the K-axis split when the\n    // weight is an initializer, but non-terminal MatMul/Gemm slices or\n    // slices whose weight is a runtime tensor still benefit from the\n    // generic axis-0 (batch) fallback, which is always semantically\n    // sound because the batch dimension is independent across rows.\n    for node in nodes {\n        if matches!(\n            node.op_type.as_str(),\n            \"Conv\"\n                | \"ConvTranspose\"\n                | \"AveragePool\"\n                | \"MaxPool\"\n                | \"GlobalAveragePool\"\n                | \"GlobalMaxPool\"\n                | \"LRN\"\n        ) {\n            return None;\n        }\n    }\n\n    // Find the deepest split_dim that is still compatible with every\n    // normalization-style op in the slice.  Splitting a later axis produces\n    // more groups and a smaller per-group cost without violating op semantics.\n    let rank = first_input_shape.len();\n    // If the slice contains any axis-reordering op (Transpose) AND any\n    // axis-sensitive normalization op (LayerNormalization / Softmax),\n    // we can no longer cheaply trace which axis the normalization\n    // really runs on after the reorder.  
Restrict the split to axis 0\n    // (always the batch dim, always semantically sound) so we never\n    // emit a split that lands on the post-Transpose normalization axis.\n    let has_transpose = nodes.iter().any(|n| n.op_type == \"Transpose\");\n    let has_norm = nodes.iter().any(|n| {\n        matches!(\n            n.op_type.as_str(),\n            \"LayerNormalization\" | \"Softmax\" | \"LogSoftmax\"\n        )\n    });\n    let mut max_allowed = if has_transpose && has_norm { 1 } else { rank };\n    for node in nodes {\n        match node.op_type.as_str() {\n            \"LayerNormalization\" => {\n                let axis = onnx_proto::get_attribute_int(node, \"axis\").unwrap_or(-1);\n                let resolved = if axis < 0 {\n                    (rank as i64 + axis).max(0) as usize\n                } else {\n                    (axis as usize).min(rank)\n                };\n                if resolved < max_allowed {\n                    max_allowed = resolved;\n                }\n            }\n            \"Softmax\" | \"LogSoftmax\" => {\n                // ONNX Softmax / LogSoftmax default axis: opset >=\n                // 13 -> -1 (last axis), opset < 13 -> 1 (channel\n                // axis).  
unwrap_or(-1) silently mismatches runtime\n                // semantics on opset < 13 models that omit the\n                // attribute.\n                let default_axis: i64 = if model_opset >= 13 { -1 } else { 1 };\n                let axis = onnx_proto::get_attribute_int(node, \"axis\").unwrap_or(default_axis);\n                let resolved = if axis < 0 {\n                    (rank as i64 + axis).max(0) as usize\n                } else {\n                    (axis as usize).min(rank.saturating_sub(1))\n                };\n                if resolved < max_allowed {\n                    max_allowed = resolved;\n                }\n            }\n            \"BatchNormalization\" => {\n                // BatchNorm couples every spatial element to the\n                // running mean / variance per channel; splitting any\n                // axis would change those statistics.  Force-reject\n                // the dim-split unconditionally so the `max_allowed == 0`\n                // early return after this loop fires regardless of prior\n                // state.\n                max_allowed = 0;\n            }\n            _ => {}\n        }\n    }\n    if max_allowed == 0 {\n        return None;\n    }\n\n    let mut best: Option<(usize, usize)> = None;\n    for (d, &axis_len) in first_input_shape.iter().enumerate().take(max_allowed) {\n        let dim = axis_len.max(1) as usize;\n        if dim <= 1 {\n            continue;\n        }\n        if best.map(|(_, size)| dim > size).unwrap_or(true) {\n            best = Some((d, dim));\n        }\n    }\n    let (split_dim, dim_size) = best?;\n\n    let num_groups = smallest_divisor_at_least(dim_size, target_groups)?;\n    let elements_per_group = dim_size / num_groups;\n    let input_name = first_non_init_input.cloned()?;\n    let output_name = nodes\n        .last()\n        .and_then(|n| n.output.first())\n        .filter(|s| !s.is_empty())\n        .cloned()?;\n    // Require the final output shape to preserve rank and the split\n    // axis size; 
otherwise an intermediate op (Reshape, Transpose,\n    // Flatten, Squeeze, Unsqueeze) has reordered the axes and\n    // concat_axis=split_dim would splice the groups into the wrong\n    // output dimension.  Tracing the axis through an arbitrary chain\n    // of shape ops is out of scope here, so we conservatively reject.\n    let out_shape = shapes.get(&output_name)?;\n    if out_shape.len() != first_input_shape.len()\n        || out_shape[split_dim] != first_input_shape[split_dim]\n    {\n        return None;\n    }\n    Some(DimSplitDetection {\n        split_kind: DimSplitKind::BatchDim,\n        split_dim,\n        dim_size,\n        num_groups,\n        elements_per_group,\n        input_name,\n        output_name,\n        concat_axis: split_dim,\n        estimated_constraints: estimated,\n        weight_name: None,\n        k_dim: 0,\n        n_dim: 0,\n        k_chunks: 1,\n    })\n}\n\n#[derive(Debug, Clone)]\npub enum TilingDetection {\n    Spatial {\n        input_name: String,\n        output_name: String,\n        input_names: Vec<String>,\n        ndim: i64,\n        c_in: i64,\n        c_out: i64,\n        h: i64,\n        w: i64,\n        tile_size: i64,\n        halo: [i64; 4],\n        tiles_y: i64,\n        tiles_x: i64,\n        out_tile: [i64; 2],\n        stride: [i64; 2],\n    },\n    ChannelSplit {\n        input_name: String,\n        output_name: String,\n        c_in: i64,\n        c_out: i64,\n        h: i64,\n        w: i64,\n        num_groups: i64,\n        channels_per_group: i64,\n    },\n    FixedSegment {\n        input_name: String,\n        output_name: String,\n        input_names: Vec<String>,\n        total_elements: i64,\n        segment_size: i64,\n        num_segments: i64,\n        original_shape: Vec<i64>,\n    },\n}\n\nstruct SpatialTileGeometry {\n    c_in: i64,\n    c_out: i64,\n    tile_h: i64,\n    tile_w: i64,\n    out_h: i64,\n    out_w: i64,\n}\n\nfn compute_spatial_tile_geometry(\n    graph: &GraphProto,\n    
pads: [i64; 4],\n    kernel: [i64; 2],\n    dilation: [i64; 2],\n    stride: [i64; 2],\n    tile_size: i64,\n    c_out_override: Option<i64>,\n) -> Result<SpatialTileGeometry> {\n    let halo = compute_halo_size(pads).ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\"spatial tile: invalid pad values\".to_string())\n    })?;\n    let tile_h = tile_size\n        .checked_add(halo[0])\n        .and_then(|v| v.checked_add(halo[2]))\n        .ok_or_else(|| {\n            crate::error::DsperseError::Slicer(format!(\n                \"spatial tile: tile_h overflow (tile_size={tile_size}, halo={:?})\",\n                halo\n            ))\n        })?;\n    let tile_w = tile_size\n        .checked_add(halo[1])\n        .and_then(|v| v.checked_add(halo[3]))\n        .ok_or_else(|| {\n            crate::error::DsperseError::Slicer(format!(\n                \"spatial tile: tile_w overflow (tile_size={tile_size}, halo={:?})\",\n                halo\n            ))\n        })?;\n    let (out_h, out_w) = conv_output_hw(tile_h, tile_w, [0, 0, 0, 0], kernel, dilation, stride)\n        .ok_or_else(|| {\n            crate::error::DsperseError::Slicer(format!(\n                \"spatial tile: invalid output dims for tile_h={tile_h}, tile_w={tile_w}, stride={stride:?}, kernel={kernel:?}\"\n            ))\n        })?;\n    let c_in = graph\n        .input\n        .first()\n        .map(onnx_proto::vi_shape)\n        .and_then(|s| (s.len() == 4 && s[1] > 0).then_some(s[1]))\n        .ok_or_else(|| {\n            crate::error::DsperseError::Slicer(\n                \"spatial tile: unable to determine input channels\".to_string(),\n            )\n        })?;\n    let c_out = c_out_override.unwrap_or(c_in);\n    Ok(SpatialTileGeometry {\n        c_in,\n        c_out,\n        tile_h,\n        tile_w,\n        out_h,\n        out_w,\n    })\n}\n\nstruct TileModelSpec {\n    nodes: Vec<NodeProto>,\n    input: onnx_proto::ValueInfoProto,\n    output: 
onnx_proto::ValueInfoProto,\n    initializers: Vec<onnx_proto::TensorProto>,\n    out_hw: [i64; 2],\n}\n\nfn save_tile_model(\n    model: &ModelProto,\n    spec: TileModelSpec,\n    slice_idx: usize,\n    output_dir: &Path,\n) -> Result<TileSliceResult> {\n    let graph = onnx_proto::make_graph(\n        &format!(\"tile_{slice_idx}\"),\n        spec.nodes,\n        vec![spec.input],\n        vec![spec.output],\n        spec.initializers,\n    );\n    let tile_model = onnx_proto::make_model(graph, model_opset(model));\n    let tiles_dir = output_dir.join(\"tiles\");\n    std::fs::create_dir_all(&tiles_dir)\n        .map_err(|e| crate::error::DsperseError::io(e, &tiles_dir))?;\n    let onnx_path = tiles_dir.join(\"tile.onnx\");\n    onnx_proto::save_model(&tile_model, &onnx_path)?;\n    Ok(TileSliceResult {\n        path: format!(\"slice_{slice_idx}/payload/tiles/tile.onnx\"),\n        conv_out: spec.out_hw,\n    })\n}\n\npub fn create_tile_slice(\n    model: &ModelProto,\n    tile_size: i64,\n    slice_idx: usize,\n    output_dir: &Path,\n) -> Result<TileSliceResult> {\n    if tile_size <= 0 {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"create_tile_slice: tile_size must be > 0, got {tile_size}\"\n        )));\n    }\n    let SlicePrologue {\n        graph,\n        cp,\n        weights,\n        bias,\n    } = extract_slice_prologue(model).ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\n            \"create_tile_slice: failed to extract slice prologue\".to_string(),\n        )\n    })?;\n    let conv_node = &graph.node[cp.node_idx];\n    let weights = weights.ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\"create_tile_slice: conv weights not found\".to_string())\n    })?;\n\n    let cfg_c_in = cp.c_in.checked_mul(cp.group).filter(|&v| v > 0);\n    let geom = compute_spatial_tile_geometry(\n        graph,\n        cp.pads,\n        cp.kernel,\n        cp.dilation,\n        cp.stride,\n        tile_size,\n 
       Some(weights.dims[0]),\n    )?;\n    if let Some(c) = cfg_c_in\n        && geom.c_in != c\n    {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"create_tile_slice: graph c_in ({}) != weight c_in*group ({c})\",\n            geom.c_in\n        )));\n    }\n\n    let x = onnx_proto::make_tensor_value_info(\n        \"tile_in\",\n        TensorProto::FLOAT,\n        &[1, geom.c_in, geom.tile_h, geom.tile_w],\n    );\n    let y = onnx_proto::make_tensor_value_info(\n        \"tile_out\",\n        TensorProto::FLOAT,\n        &[1, geom.c_out, geom.out_h, geom.out_w],\n    );\n\n    let mut initializers = vec![onnx_proto::make_tensor(\n        \"W\",\n        TensorProto::FLOAT,\n        &weights.dims,\n        weights.data,\n    )];\n    let mut conv_inputs = vec![\"tile_in\".to_string(), \"W\".to_string()];\n\n    if let Some(bias_data) = &bias {\n        let bias_dims = [geom.c_out];\n        initializers.push(onnx_proto::make_tensor(\n            \"B\",\n            TensorProto::FLOAT,\n            &bias_dims,\n            bias_data.clone(),\n        ));\n        conv_inputs.push(\"B\".to_string());\n    }\n\n    let mut conv_attrs = vec![\n        onnx_proto::make_attribute_ints(\"kernel_shape\", &cp.kernel),\n        onnx_proto::make_attribute_ints(\"strides\", &cp.stride),\n        onnx_proto::make_attribute_ints(\"pads\", &[0, 0, 0, 0]),\n        onnx_proto::make_attribute_ints(\"dilations\", &cp.dilation),\n    ];\n    if cp.group != 1 {\n        conv_attrs.push(onnx_proto::make_attribute_int(\"group\", cp.group));\n    }\n\n    let mut nodes = vec![onnx_proto::make_node(\n        \"Conv\",\n        conv_inputs,\n        vec![\"conv_out\".to_string()],\n        conv_attrs,\n    )];\n\n    integrate_extra_ops(graph, conv_node, &mut initializers, &mut nodes)?;\n\n    save_tile_model(\n        model,\n        TileModelSpec {\n            nodes,\n            input: x,\n            output: y,\n            initializers,\n           
 out_hw: [geom.out_h, geom.out_w],\n        },\n        slice_idx,\n        output_dir,\n    )\n}\n\npub fn create_pool_tile_slice(\n    model: &ModelProto,\n    tile_size: i64,\n    slice_idx: usize,\n    output_dir: &Path,\n) -> Result<TileSliceResult> {\n    if tile_size <= 0 {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"create_pool_tile_slice: tile_size must be > 0, got {tile_size}\"\n        )));\n    }\n    let graph = model.graph.as_ref().ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\n            \"create_pool_tile_slice: model.graph is None\".to_string(),\n        )\n    })?;\n    let pp = get_pool_params(graph).ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\n            \"create_pool_tile_slice: no MaxPool node found\".to_string(),\n        )\n    })?;\n    let pool_node = &graph.node[pp.node_idx];\n\n    let geom = compute_spatial_tile_geometry(\n        graph,\n        pp.pads,\n        pp.kernel,\n        pp.dilation,\n        pp.stride,\n        tile_size,\n        None,\n    )?;\n\n    let x = onnx_proto::make_tensor_value_info(\n        \"tile_in\",\n        TensorProto::FLOAT,\n        &[1, geom.c_in, geom.tile_h, geom.tile_w],\n    );\n    let y = onnx_proto::make_tensor_value_info(\n        \"tile_out\",\n        TensorProto::FLOAT,\n        &[1, geom.c_out, geom.out_h, geom.out_w],\n    );\n\n    let pool_attrs = vec![\n        onnx_proto::make_attribute_ints(\"kernel_shape\", &pp.kernel),\n        onnx_proto::make_attribute_ints(\"strides\", &pp.stride),\n        onnx_proto::make_attribute_ints(\"pads\", &[0, 0, 0, 0]),\n        onnx_proto::make_attribute_ints(\"dilations\", &pp.dilation),\n    ];\n    let mut nodes = vec![onnx_proto::make_node(\n        \"MaxPool\",\n        vec![\"tile_in\".to_string()],\n        vec![\"pool_out\".to_string()],\n        pool_attrs,\n    )];\n\n    let mut initializers = Vec::new();\n    integrate_extra_ops(graph, pool_node, &mut initializers, &mut 
nodes)?;\n\n    save_tile_model(\n        model,\n        TileModelSpec {\n            nodes,\n            input: x,\n            output: y,\n            initializers,\n            out_hw: [geom.out_h, geom.out_w],\n        },\n        slice_idx,\n        output_dir,\n    )\n}\n\nfn integrate_extra_ops(\n    graph: &GraphProto,\n    primary_node: &NodeProto,\n    initializers: &mut Vec<onnx_proto::TensorProto>,\n    nodes: &mut Vec<NodeProto>,\n) -> crate::error::Result<()> {\n    let primary_op = primary_node.op_type.as_str();\n    let orig_input_name = graph.input.first().map(|i| i.name.as_str()).unwrap_or(\"\");\n\n    let extra: Vec<&NodeProto> = graph\n        .node\n        .iter()\n        .filter(|n| n.op_type != primary_op)\n        .collect();\n\n    if extra.is_empty() {\n        let last = nodes.last_mut().ok_or_else(|| {\n            crate::error::DsperseError::Slicer(\n                \"integrate_extra_ops: no nodes to set output on\".into(),\n            )\n        })?;\n        let out = last.output.get_mut(0).ok_or_else(|| {\n            crate::error::DsperseError::Slicer(\n                \"integrate_extra_ops: last node has no outputs\".into(),\n            )\n        })?;\n        *out = \"tile_out\".to_string();\n        return Ok(());\n    }\n\n    let mut primary_weight_names: HashSet<String> = HashSet::new();\n    for inp in primary_node.input.iter().skip(1) {\n        primary_weight_names.insert(inp.clone());\n    }\n\n    for init in &graph.initializer {\n        if !primary_weight_names.contains(&init.name) {\n            initializers.push(init.clone());\n        }\n    }\n\n    let primary_outputs: HashSet<String> = graph\n        .node\n        .iter()\n        .filter(|n| n.op_type == primary_op)\n        .flat_map(|n| n.output.iter().cloned())\n        .collect();\n\n    let primary_out_wire = nodes\n        .last()\n        .and_then(|n| n.output.first())\n        .cloned()\n        .unwrap_or_else(|| format!(\"{}_out\", 
primary_op.to_lowercase()));\n\n    for (i, orig_node) in extra.iter().enumerate() {\n        let new_inputs: Vec<String> = orig_node\n            .input\n            .iter()\n            .map(|inp| {\n                if primary_outputs.contains(inp) {\n                    primary_out_wire.clone()\n                } else if inp == orig_input_name {\n                    \"tile_in\".to_string()\n                } else {\n                    inp.clone()\n                }\n            })\n            .collect();\n\n        let is_last = i == extra.len() - 1;\n        let new_outputs = if is_last {\n            vec![\"tile_out\".to_string()]\n        } else {\n            orig_node.output.clone()\n        };\n\n        nodes.push(NodeProto {\n            op_type: orig_node.op_type.clone(),\n            input: new_inputs,\n            output: new_outputs,\n            attribute: orig_node.attribute.clone(),\n            name: String::new(),\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        });\n    }\n\n    Ok(())\n}\n\n#[allow(clippy::too_many_arguments)]\nfn create_channel_group_slice(\n    model: &ModelProto,\n    prologue: &SlicePrologue<'_>,\n    group_idx: usize,\n    c_start: i64,\n    c_end: i64,\n    h_in: i64,\n    w_in: i64,\n    slice_idx: usize,\n    output_dir: &Path,\n) -> Result<ChannelGroupInfo> {\n    let cp = &prologue.cp;\n    if c_start < 0 || c_end < 0 || c_start >= c_end {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"create_channel_group_slice: invalid channel range c_start={c_start}, c_end={c_end}\"\n        )));\n    }\n    let weights = prologue.weights.as_ref().ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\n            \"create_channel_group_slice: conv weights not found\".to_string(),\n        )\n    })?;\n\n    let c_group = c_end - 
c_start;\n    let (h_out, w_out) = conv_output_hw(h_in, w_in, cp.pads, cp.kernel, cp.dilation, cp.stride)\n        .ok_or_else(|| {\n            crate::error::DsperseError::Slicer(format!(\n                \"create_channel_group_slice: invalid output dims for h_in={h_in}, w_in={w_in}\"\n            ))\n        })?;\n    let c_out = cp.c_out;\n\n    let input_name = format!(\"group_{group_idx}_in\");\n    let output_name = format!(\"group_{group_idx}_out\");\n\n    let x = onnx_proto::make_tensor_value_info(\n        &input_name,\n        TensorProto::FLOAT,\n        &[1, c_group, h_in, w_in],\n    );\n    let y = onnx_proto::make_tensor_value_info(\n        &output_name,\n        TensorProto::FLOAT,\n        &[1, c_out, h_out, w_out],\n    );\n\n    let c_start_uz = i64_to_usize(c_start, \"create_channel_group_slice\", \"c_start\")?;\n    let c_end_uz = i64_to_usize(c_end, \"create_channel_group_slice\", \"c_end\")?;\n    let sliced_weights = slice_weights(weights, c_start_uz, c_end_uz)?;\n\n    let w_tensor = onnx_proto::make_tensor(\n        \"W\",\n        TensorProto::FLOAT,\n        &sliced_weights.dims,\n        sliced_weights.data,\n    );\n\n    let mut conv_attrs = vec![\n        onnx_proto::make_attribute_ints(\"kernel_shape\", &cp.kernel),\n        onnx_proto::make_attribute_ints(\"strides\", &cp.stride),\n        onnx_proto::make_attribute_ints(\"pads\", &cp.pads),\n        onnx_proto::make_attribute_ints(\"dilations\", &cp.dilation),\n    ];\n    if cp.group != 1 {\n        conv_attrs.push(onnx_proto::make_attribute_int(\"group\", cp.group));\n    }\n\n    let node = onnx_proto::make_node(\n        \"Conv\",\n        vec![input_name, \"W\".to_string()],\n        vec![output_name],\n        conv_attrs,\n    );\n\n    let graph_proto = onnx_proto::make_graph(\n        &format!(\"channel_group_{slice_idx}_{group_idx}\"),\n        vec![node],\n        vec![x],\n        vec![y],\n        vec![w_tensor],\n    );\n    let group_model = 
onnx_proto::make_model(graph_proto, model_opset(model));\n\n    let groups_dir = output_dir.join(\"channel_groups\");\n    std::fs::create_dir_all(&groups_dir)\n        .map_err(|e| crate::error::DsperseError::io(e, &groups_dir))?;\n    let onnx_path = groups_dir.join(format!(\"group_{group_idx}.onnx\"));\n    onnx_proto::save_model(&group_model, &onnx_path)?;\n\n    Ok(ChannelGroupInfo {\n        group_idx,\n        c_start: c_start_uz,\n        c_end: c_end_uz,\n        path: format!(\"slice_{slice_idx}/payload/channel_groups/group_{group_idx}.onnx\"),\n        jstprove_circuit_path: None,\n        jstprove_settings_path: None,\n    })\n}\n\nfn i64_to_usize(val: i64, ctx: &str, name: &str) -> Result<usize> {\n    usize::try_from(val).map_err(|_| {\n        crate::error::DsperseError::Slicer(format!(\"{ctx}: {name} ({val}) out of range for usize\"))\n    })\n}\n\nfn checked_dim_product(factors: &[usize]) -> Result<usize> {\n    factors.iter().try_fold(1usize, |acc, &f| {\n        acc.checked_mul(f).ok_or_else(|| {\n            crate::error::DsperseError::Slicer(format!(\n                \"slice_weights: dimension product overflow (factors={factors:?})\"\n            ))\n        })\n    })\n}\n\nfn slice_weights(weights: &WeightInfo, c_start: usize, c_end: usize) -> Result<WeightInfo> {\n    if weights.dims.len() < 4 {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"slice_weights: expected >= 4 dims, got {}\",\n            weights.dims.len()\n        )));\n    }\n    let to_usize = |dim: i64, name: &str| -> Result<usize> {\n        usize::try_from(dim).map_err(|_| {\n            crate::error::DsperseError::Slicer(format!(\n                \"slice_weights: {name} dimension {dim} is negative or too large\"\n            ))\n        })\n    };\n    let c_out = to_usize(weights.dims[0], \"c_out\")?;\n    let c_in = to_usize(weights.dims[1], \"c_in\")?;\n    let kh = to_usize(weights.dims[2], \"kh\")?;\n    let kw = 
to_usize(weights.dims[3], \"kw\")?;\n    let expected_len = checked_dim_product(&[c_out, c_in, kh, kw])?;\n    if weights.data.len() != expected_len {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"slice_weights: data length {} != expected {} (dims={:?})\",\n            weights.data.len(),\n            expected_len,\n            weights.dims\n        )));\n    }\n    if c_start >= c_end {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"slice_weights: c_start ({c_start}) >= c_end ({c_end})\"\n        )));\n    }\n    if c_end > c_in {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"slice_weights: c_end ({c_end}) exceeds c_in ({c_in})\"\n        )));\n    }\n    let c_group = c_end - c_start;\n    let capacity = checked_dim_product(&[c_out, c_group, kh, kw])?;\n    let stride_cin = checked_dim_product(&[c_in, kh, kw])?;\n    let stride_kh = checked_dim_product(&[kh, kw])?;\n\n    let mut sliced = Vec::with_capacity(capacity);\n    for o in 0..c_out {\n        for c in c_start..c_end {\n            for h in 0..kh {\n                for w_idx in 0..kw {\n                    let idx = o * stride_cin + c * stride_kh + h * kw + w_idx;\n                    sliced.push(weights.data[idx]);\n                }\n            }\n        }\n    }\n\n    Ok(WeightInfo {\n        data: sliced,\n        dims: vec![c_out as i64, c_group as i64, kh as i64, kw as i64],\n    })\n}\n\nfn save_conv_bias(\n    prologue: &SlicePrologue<'_>,\n    slice_idx: usize,\n    output_dir: &Path,\n) -> Result<Option<String>> {\n    let Some(bias_data) = &prologue.bias else {\n        return Ok(None);\n    };\n\n    let groups_dir = output_dir.join(\"channel_groups\");\n    std::fs::create_dir_all(&groups_dir)\n        .map_err(|e| crate::error::DsperseError::io(e, &groups_dir))?;\n\n    let bias_bytes = rmp_serde::to_vec_named(&bias_data)?;\n    let bias_path = groups_dir.join(\"bias.msgpack\");\n    
std::fs::write(&bias_path, bias_bytes)\n        .map_err(|e| crate::error::DsperseError::io(e, &bias_path))?;\n\n    Ok(Some(format!(\n        \"slice_{slice_idx}/payload/channel_groups/bias.msgpack\"\n    )))\n}\n\n#[allow(clippy::too_many_arguments)]\npub fn apply_channel_splitting(\n    model: &ModelProto,\n    cfg: &ChannelSplitParams,\n    input_name: &str,\n    output_name: &str,\n    output_dir: &Path,\n) -> Result<ChannelSplitInfo> {\n    let &ChannelSplitParams {\n        c_in,\n        c_out,\n        num_groups,\n        channels_per_group,\n        h,\n        w,\n        slice_idx,\n    } = cfg;\n    if c_in <= 0 || c_out <= 0 || num_groups <= 0 || channels_per_group <= 0 || h <= 0 || w <= 0 {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"apply_channel_splitting: invalid ChannelSplitParams (c_in={c_in}, c_out={c_out}, num_groups={num_groups}, channels_per_group={channels_per_group}, h={h}, w={w})\"\n        )));\n    }\n    let covered = num_groups.checked_mul(channels_per_group).ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\n            \"apply_channel_splitting: num_groups * channels_per_group overflow\".to_string(),\n        )\n    })?;\n    if covered < c_in {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"apply_channel_splitting: cfg covers only {covered} input channels, expected at least {c_in}\",\n        )));\n    }\n    let last_group_start = (num_groups - 1)\n        .checked_mul(channels_per_group)\n        .ok_or_else(|| {\n            crate::error::DsperseError::Slicer(\n                \"apply_channel_splitting: group start computation overflow\".to_string(),\n            )\n        })?;\n    if last_group_start >= c_in {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"apply_channel_splitting: cfg creates empty trailing groups (last_start={last_group_start}, c_in={c_in})\"\n        )));\n    }\n    let prologue = 
extract_slice_prologue(model).ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\n            \"apply_channel_splitting: failed to extract slice prologue from model\".to_string(),\n        )\n    })?;\n\n    let (_, _, model_c_in, model_h, model_w) =\n        get_model_dimensions(prologue.graph).ok_or_else(|| {\n            crate::error::DsperseError::Slicer(\n                \"apply_channel_splitting: unable to determine model dimensions\".to_string(),\n            )\n        })?;\n    let model_c_out = prologue.cp.c_out;\n    if prologue.cp.group != 1 {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"apply_channel_splitting: unsupported Conv group={}, expected 1\",\n            prologue.cp.group\n        )));\n    }\n    if prologue.cp.c_in != model_c_in {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"apply_channel_splitting: weight/model c_in mismatch (weights c_in={}, model c_in={})\",\n            prologue.cp.c_in, model_c_in\n        )));\n    }\n    if model_c_in != c_in || model_c_out != c_out || model_h != h || model_w != w {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"apply_channel_splitting: cfg dims (c_in={c_in}, c_out={c_out}, h={h}, w={w}) mismatch model dims (c_in={model_c_in}, c_out={model_c_out}, h={model_h}, w={model_w})\"\n        )));\n    }\n\n    let (out_h, out_w) = conv_output_hw(\n        h,\n        w,\n        prologue.cp.pads,\n        prologue.cp.kernel,\n        prologue.cp.dilation,\n        prologue.cp.stride,\n    )\n    .ok_or_else(|| {\n        crate::error::DsperseError::Slicer(format!(\n            \"apply_channel_splitting: invalid conv output dimensions for h={h}, w={w}, stride={:?}, kernel={:?}\",\n            prologue.cp.stride, prologue.cp.kernel\n        ))\n    })?;\n\n    let groups_dir = output_dir.join(\"channel_groups\");\n    let cleanup = || {\n        if groups_dir.exists() {\n            let _ = 
std::fs::remove_dir_all(&groups_dir);\n        }\n    };\n\n    let mut groups = Vec::new();\n    for g in 0..num_groups {\n        let c_start = g * channels_per_group;\n        let c_end = ((g + 1) * channels_per_group).min(c_in);\n\n        let g_uz = i64_to_usize(g, \"apply_channel_splitting\", \"group_idx\").inspect_err(|_| {\n            cleanup();\n        })?;\n        let group_info = match create_channel_group_slice(\n            model, &prologue, g_uz, c_start, c_end, h, w, slice_idx, output_dir,\n        ) {\n            Ok(info) => info,\n            Err(e) => {\n                cleanup();\n                return Err(e);\n            }\n        };\n        groups.push(group_info);\n    }\n\n    let bias_path = match save_conv_bias(&prologue, slice_idx, output_dir) {\n        Ok(p) => p,\n        Err(e) => {\n            cleanup();\n            return Err(e);\n        }\n    };\n\n    let ctx = \"apply_channel_splitting\";\n    let c_in_uz = i64_to_usize(c_in, ctx, \"c_in\").inspect_err(|_| cleanup())?;\n    let c_out_uz = i64_to_usize(c_out, ctx, \"c_out\").inspect_err(|_| cleanup())?;\n    let num_groups_uz = i64_to_usize(num_groups, ctx, \"num_groups\").inspect_err(|_| cleanup())?;\n    let cpg_uz =\n        i64_to_usize(channels_per_group, ctx, \"channels_per_group\").inspect_err(|_| cleanup())?;\n    let h_uz = i64_to_usize(h, ctx, \"h\").inspect_err(|_| cleanup())?;\n    let w_uz = i64_to_usize(w, ctx, \"w\").inspect_err(|_| cleanup())?;\n    let out_h_uz = i64_to_usize(out_h, ctx, \"out_h\").inspect_err(|_| cleanup())?;\n    let out_w_uz = i64_to_usize(out_w, ctx, \"out_w\").inspect_err(|_| cleanup())?;\n    Ok(ChannelSplitInfo {\n        slice_idx,\n        c_in: c_in_uz,\n        c_out: c_out_uz,\n        num_groups: num_groups_uz,\n        channels_per_group: cpg_uz,\n        input_name: input_name.to_string(),\n        output_name: output_name.to_string(),\n        h: h_uz,\n        w: w_uz,\n        out_h: out_h_uz,\n        out_w: 
out_w_uz,\n        groups,\n        bias_path,\n    })\n}\n\npub fn create_dim_split_template(\n    model: &ModelProto,\n    info: &crate::schema::tiling::DimSplitInfo,\n    output_dir: &Path,\n    traced_shapes: Option<&HashMap<String, Vec<i64>>>,\n) -> Result<std::path::PathBuf> {\n    let graph = model.graph.as_ref().ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\"create_dim_split_template: model has no graph\".into())\n    })?;\n\n    match info.split_kind {\n        crate::schema::tiling::DimSplitKind::MatMulOutputDim => {\n            create_matmul_dim_template(model, graph, info, output_dir)\n        }\n        crate::schema::tiling::DimSplitKind::HeadDim\n        | crate::schema::tiling::DimSplitKind::BatchDim => {\n            create_generic_dim_template(model, graph, info, output_dir, traced_shapes)\n        }\n    }\n}\n\nfn create_matmul_dim_template(\n    model: &ModelProto,\n    graph: &GraphProto,\n    info: &crate::schema::tiling::DimSplitInfo,\n    output_dir: &Path,\n) -> Result<std::path::PathBuf> {\n    let weight_name = info.weight_name.as_ref().ok_or_else(|| {\n        crate::error::DsperseError::Slicer(format!(\n            \"create_matmul_dim_template: slice {} DimSplitInfo missing weight_name\",\n            info.slice_idx\n        ))\n    })?;\n\n    // Match the exact split node by weight, activation input, and output\n    // name. 
A graph may reuse the same weight initializer in multiple\n    // MatMul/Gemm ops (tied weights, weight sharing across heads); without\n    // checking IO we could bind the wrong op and emit a template that\n    // doesn't match the slice the runner will execute.\n    let matmul_node = graph\n        .node\n        .iter()\n        .find(|n| {\n            matches!(n.op_type.as_str(), \"MatMul\" | \"Gemm\")\n                && n.input.iter().any(|i| i == weight_name)\n                && n.input.iter().any(|i| i == &info.input_name)\n                && n.output.iter().any(|o| o == &info.output_name)\n        })\n        .ok_or_else(|| {\n            crate::error::DsperseError::Slicer(format!(\n                \"create_matmul_dim_template: slice {} no MatMul/Gemm matches weight={weight_name:?} input={:?} output={:?}\",\n                info.slice_idx, info.input_name, info.output_name\n            ))\n        })?;\n\n    if matmul_node.op_type == \"Gemm\" && matmul_node.input.get(2).is_some_and(|s| !s.is_empty()) {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"create_matmul_dim_template: slice {} Gemm with bias not supported for dim-split\",\n            info.slice_idx\n        )));\n    }\n\n    let weight_tensor = graph\n        .initializer\n        .iter()\n        .find(|i| i.name == *weight_name)\n        .ok_or_else(|| {\n            crate::error::DsperseError::Slicer(format!(\n                \"create_matmul_dim_template: weight {weight_name:?} not in initializers\"\n            ))\n        })?;\n    if weight_tensor.dims.len() != 2 {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"create_matmul_dim_template: expected 2D weights, got {:?}\",\n            weight_tensor.dims\n        )));\n    }\n\n    if matmul_node.op_type == \"Gemm\"\n        && onnx_proto::get_attribute_int(matmul_node, \"transA\").unwrap_or(0) == 1\n    {\n        return Err(crate::error::DsperseError::Slicer(format!(\n        
    \"create_matmul_dim_template: slice {} Gemm with transA=1 is not supported for dim-split\",\n            info.slice_idx\n        )));\n    }\n\n    let trans_b = matmul_node.op_type == \"Gemm\"\n        && onnx_proto::get_attribute_int(matmul_node, \"transB\").unwrap_or(0) == 1;\n\n    let (rows, cols) = (\n        weight_tensor.dims[0] as usize,\n        weight_tensor.dims[1] as usize,\n    );\n    let (k_dim, n_dim) = if trans_b { (cols, rows) } else { (rows, cols) };\n    let k_chunk_size = k_dim.div_ceil(info.k_chunks.max(1));\n\n    let tmpl_input_name = \"dim_tmpl_in\".to_string();\n    let tmpl_output_name = \"dim_tmpl_out\".to_string();\n    let tmpl_weight_name = \"W\".to_string();\n\n    let tmpl_input_shape: Vec<i64> = vec![1, k_chunk_size as i64];\n    let output_shape: Vec<i64> = vec![1, n_dim as i64];\n\n    let x =\n        onnx_proto::make_tensor_value_info(&tmpl_input_name, TensorProto::FLOAT, &tmpl_input_shape);\n    let y =\n        onnx_proto::make_tensor_value_info(&tmpl_output_name, TensorProto::FLOAT, &output_shape);\n    let tmpl_weight_dims: Vec<i64> = if trans_b {\n        vec![n_dim as i64, k_chunk_size as i64]\n    } else {\n        vec![k_chunk_size as i64, n_dim as i64]\n    };\n    let w = onnx_proto::make_tensor(\n        &tmpl_weight_name,\n        TensorProto::FLOAT,\n        &tmpl_weight_dims,\n        vec![0.0f32; k_chunk_size * n_dim],\n    );\n\n    let mut attrs = Vec::new();\n    let node_inputs = vec![tmpl_input_name, tmpl_weight_name];\n    let initializers = vec![w];\n\n    if matmul_node.op_type == \"Gemm\" {\n        if let Some(alpha) = onnx_proto::get_attribute_float(matmul_node, \"alpha\") {\n            attrs.push(onnx_proto::make_attribute_float(\"alpha\", alpha));\n        }\n        if let Some(beta) = onnx_proto::get_attribute_float(matmul_node, \"beta\") {\n            attrs.push(onnx_proto::make_attribute_float(\"beta\", beta));\n        }\n        // transA is rejected above; the template always uses A 
non-transposed.\n        if trans_b {\n            attrs.push(onnx_proto::make_attribute_int(\"transB\", 1));\n        }\n        // Biased Gemm is rejected above, so no C initializer is ever folded\n        // into the template.\n    }\n\n    let node = onnx_proto::make_node(\n        &matmul_node.op_type,\n        node_inputs,\n        vec![tmpl_output_name],\n        attrs,\n    );\n\n    let graph_proto = onnx_proto::make_graph(\n        &format!(\"dim_template_{}\", info.slice_idx),\n        vec![node],\n        vec![x],\n        vec![y],\n        initializers,\n    );\n    let tmpl_model = onnx_proto::make_model(graph_proto, model_opset(model));\n\n    let tmpl_path = output_dir.join(\"dim_template.onnx\");\n    onnx_proto::save_model(&tmpl_model, &tmpl_path)?;\n    Ok(tmpl_path)\n}\n\nfn check_axis_separable(\n    graph: &GraphProto,\n    split_dim: usize,\n    slice_idx: usize,\n    model_opset: i64,\n) -> Result<()> {\n    let resolve_axis = |axis: i64| -> usize {\n        let ndim = graph\n            .input\n            .first()\n            .and_then(onnx_proto::shape_from_value_info)\n            .map(|s| s.len() as i64)\n            .unwrap_or(4);\n        if axis < 0 {\n            (ndim + axis) as usize\n        } else {\n            axis as usize\n        }\n    };\n\n    for node in &graph.node {\n        match node.op_type.as_str() {\n            \"Flatten\" => {\n                let axis = resolve_axis(onnx_proto::get_attribute_int(node, \"axis\").unwrap_or(1));\n                if split_dim < axis {\n                    return Err(crate::error::DsperseError::Slicer(format!(\n                        \"create_generic_dim_template: slice {slice_idx} Flatten axis \\\n                         {axis} > split_dim {split_dim}; split dimension falls in the merged leading group\"\n                    )));\n                }\n            }\n            \"Softmax\" | \"LogSoftmax\" => {\n                // ONNX Softmax / LogSoftmax default axis: opset >=\n 
               // 13 -> -1 (last axis), opset < 13 -> 1 (channel\n                // axis).\n                let default_axis: i64 = if model_opset >= 13 { -1 } else { 1 };\n                let resolved = resolve_axis(\n                    onnx_proto::get_attribute_int(node, \"axis\").unwrap_or(default_axis),\n                );\n                if resolved == split_dim {\n                    return Err(crate::error::DsperseError::Slicer(format!(\n                        \"create_generic_dim_template: slice {slice_idx} {} axis {resolved} \\\n                         equals split_dim {split_dim}; normalization spans the split dimension\",\n                        node.op_type\n                    )));\n                }\n            }\n            \"LayerNormalization\" => {\n                let resolved =\n                    resolve_axis(onnx_proto::get_attribute_int(node, \"axis\").unwrap_or(-1));\n                if resolved <= split_dim {\n                    return Err(crate::error::DsperseError::Slicer(format!(\n                        \"create_generic_dim_template: slice {slice_idx} LayerNormalization axis \\\n                         {resolved} <= split_dim {split_dim}; normalization spans the split dimension\",\n                    )));\n                }\n            }\n            \"BatchNormalization\" if split_dim == 0 => {\n                return Err(crate::error::DsperseError::Slicer(format!(\n                    \"create_generic_dim_template: slice {slice_idx} BatchNormalization requires \\\n                     full batch statistics; cannot split at dim 0\"\n                )));\n            }\n            _ => {}\n        }\n    }\n    Ok(())\n}\n\nfn create_generic_dim_template(\n    model: &ModelProto,\n    graph: &GraphProto,\n    info: &crate::schema::tiling::DimSplitInfo,\n    output_dir: &Path,\n    traced_shapes: Option<&HashMap<String, Vec<i64>>>,\n) -> Result<std::path::PathBuf> {\n    if info.elements_per_group == 0 {\n        return 
Err(crate::error::DsperseError::Slicer(format!(\n            \"create_generic_dim_template: slice {} elements_per_group is 0\",\n            info.slice_idx\n        )));\n    }\n\n    check_axis_separable(graph, info.split_dim, info.slice_idx, model_opset(model))?;\n\n    // Rewrite the template so the split axis carries elements_per_group\n    // instead of the full dim_size.  The runner only ever feeds a single\n    // group's worth of activations to the compiled circuit, so the\n    // *compile* cost should match the per-group cost rather than the\n    // whole-slice cost.  Catalog reuse is preserved at per-group\n    // granularity: any two slices that share (split_dim, epg, surrounding\n    // op shapes) hash identically.\n    //\n    // The strategy is: rewrite only the boundary shapes (graph inputs +\n    // shape-input initializers consumed by Reshape / Expand / Tile /\n    // ConstantOfShape) and a fresh shape inference pass derives every\n    // intermediate value_info from those.  Per-feature initializers\n    // (gamma, beta, weights) are never touched, and there are no ad-hoc\n    // cases for individual op patterns -- the rule is \"rewrite the\n    // boundary, let inference do the rest\".\n    let mut tmpl_model = model.clone();\n    let tmpl_graph = tmpl_model.graph.as_mut().ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\n            \"create_generic_dim_template: cloned model has no graph\".into(),\n        )\n    })?;\n    let dim_size = info.dim_size as i64;\n    let epg = info.elements_per_group as i64;\n    let split_dim = info.split_dim;\n\n    // 1. Decide which graph inputs must be rewritten at split_dim.\n    //\n    //    The runner always slices every cached tensor whose shape has\n    //    dim_size at split_dim, so the *compile-time* template needs\n    //    every such input declared with epg, otherwise jstprove's\n    //    type checker rejects the op (e.g. 
Mul broadcast 150 vs 300,\n    //    or MatMul A.K vs B.K mismatch).  But for ops where two\n    //    inputs reference dim_size at the *same* split_dim with\n    //    different semantic meanings (the canonical case is the\n    //    second attention MatMul: attn[B,H,M,N] @ V[B,H,N,D] with\n    //    M == N at split_dim=2) blanket rewriting both inputs\n    //    produces a real mismatch.\n    //\n    //    Heuristic:\n    //      * Elementwise / broadcast ops (Add, Sub, Mul, Div, Pow, Min,\n    //        Max, Where, Equal, Greater, Less): rewrite every input\n    //        whose shape has dim_size at split_dim.  All inputs share\n    //        a logical broadcast axis, so all must shrink together.\n    //      * MatMul / Gemm: rewrite only `info.input_name`.  The other\n    //        operand's split_dim is a contraction axis; touching it\n    //        produces an inner-dim mismatch.\n    //      * Everything else (the single-op slices we get after\n    //        isolate_expensive_ops): rewrite only `info.input_name`,\n    //        which is the safe default for ops with one primary\n    //        activation and a handful of scalar / per-feature\n    //        initializer inputs.\n    let elementwise_ops: HashSet<&str> = [\n        \"Add\", \"Sub\", \"Mul\", \"Div\", \"Pow\", \"Min\", \"Max\", \"Where\", \"Equal\", \"Greater\", \"Less\",\n    ]\n    .into_iter()\n    .collect();\n    let rewrite_all_matching = tmpl_graph\n        .node\n        .iter()\n        .all(|n| elementwise_ops.contains(n.op_type.as_str()));\n\n    let rewrite_input_at_split_dim = |vi: &mut super::onnx_proto::ValueInfoProto| {\n        if let Some(t) = vi.r#type.as_mut()\n            && let Some(super::onnx_proto::onnx::type_proto::Value::TensorType(tt)) =\n                t.value.as_mut()\n            && let Some(shape) = tt.shape.as_mut()\n            && let Some(d) = shape.dim.get_mut(split_dim)\n            && let 
Some(super::onnx_proto::onnx::tensor_shape_proto::dimension::Value::DimValue(v)) =\n                d.value.as_mut()\n            && *v == dim_size\n        {\n            *v = epg;\n        }\n    };\n\n    if rewrite_all_matching {\n        for vi in tmpl_graph\n            .input\n            .iter_mut()\n            .chain(tmpl_graph.output.iter_mut())\n        {\n            rewrite_input_at_split_dim(vi);\n        }\n    } else {\n        for vi in tmpl_graph\n            .input\n            .iter_mut()\n            .filter(|vi| vi.name == info.input_name)\n            .chain(\n                tmpl_graph\n                    .output\n                    .iter_mut()\n                    .filter(|vi| vi.name == info.output_name),\n            )\n        {\n            rewrite_input_at_split_dim(vi);\n        }\n    }\n\n    // 2. Rewrite shape-input initializers (Reshape / Expand / Tile /\n    //    ConstantOfShape).  These are explicit shape descriptors; if\n    //    the input shape changes their dim_size entry must change too.\n    let shape_input_initializers: HashSet<String> = tmpl_graph\n        .node\n        .iter()\n        .filter_map(|n| match n.op_type.as_str() {\n            \"Reshape\" | \"Expand\" | \"Tile\" => n.input.get(1).cloned(),\n            \"ConstantOfShape\" => n.input.first().cloned(),\n            _ => None,\n        })\n        .filter(|name| !name.is_empty())\n        .collect();\n    for init in &mut tmpl_graph.initializer {\n        if init.data_type == TensorProto::INT64 && shape_input_initializers.contains(&init.name) {\n            // ONNX TensorProto INT64 payloads can live in either\n            // int64_data (typed field) or raw_data (little-endian\n            // i64 byte stream); larger constants tend to use\n            // raw_data.  
Patch both representations.\n            for v in &mut init.int64_data {\n                if *v == dim_size {\n                    *v = epg;\n                }\n            }\n            if !init.raw_data.is_empty() && init.raw_data.len() % 8 == 0 {\n                let mut buf: Vec<i64> = init\n                    .raw_data\n                    .chunks_exact(8)\n                    .map(|c| i64::from_le_bytes(c.try_into().unwrap()))\n                    .collect();\n                let mut changed = false;\n                for v in &mut buf {\n                    if *v == dim_size {\n                        *v = epg;\n                        changed = true;\n                    }\n                }\n                if changed {\n                    let mut new_raw = Vec::with_capacity(buf.len() * 8);\n                    for v in &buf {\n                        new_raw.extend_from_slice(&v.to_le_bytes());\n                    }\n                    init.raw_data = new_raw;\n                }\n            }\n        }\n    }\n\n    // 3. Drop every intermediate value_info; it will be re-derived.\n    tmpl_graph.value_info.clear();\n\n    let _ = traced_shapes; // intentionally unused: we re-trace after rewriting.\n\n    let tmpl_path = output_dir.join(\"dim_template.onnx\");\n    onnx_proto::save_model(&tmpl_model, &tmpl_path)?;\n\n    // 4. Re-run shape inference on the rewritten template and inject\n    //    the derived shapes back as value_info.  This replaces the old\n    //    ad-hoc per-op rewrites (which had to special-case every shape\n    //    op).  If re-trace fails the template is uncompilable -- the\n    //    circuit compiler downstream will see no value_info for the\n    //    intermediate tensors and produce hard-to-diagnose shape\n    //    errors at compile time.  
Refuse to emit the template instead.\n    let trace = super::trace::fold_and_trace_via_tract(&tmpl_path, &tmpl_model).map_err(\n        |e| {\n            crate::error::DsperseError::Slicer(format!(\n                \"create_generic_dim_template: slice {} re-trace failed (template input shape {:?}, split_dim {}): {e}\",\n                info.slice_idx, info.input_name, split_dim\n            ))\n        },\n    )?;\n    {\n        let mut model_after = onnx_proto::load_model(&tmpl_path)?;\n        if let Some(graph_after) = model_after.graph.as_mut() {\n            let existing: HashSet<String> = graph_after\n                .input\n                .iter()\n                .chain(graph_after.output.iter())\n                .chain(graph_after.value_info.iter())\n                .map(|vi| vi.name.clone())\n                .collect();\n            let init_names: HashSet<&str> = graph_after\n                .initializer\n                .iter()\n                .map(|i| i.name.as_str())\n                .collect();\n            for node in &graph_after.node {\n                for out_name in &node.output {\n                    if out_name.is_empty()\n                        || existing.contains(out_name)\n                        || init_names.contains(out_name.as_str())\n                    {\n                        continue;\n                    }\n                    if let Some(shape) = trace.shapes.get(out_name) {\n                        let elem_type = trace\n                            .types\n                            .get(out_name)\n                            .copied()\n                            .unwrap_or(TensorProto::FLOAT);\n                        graph_after\n                            .value_info\n                            .push(onnx_proto::make_tensor_value_info(\n                                out_name, elem_type, shape,\n                            ));\n                    }\n                }\n            }\n            // Promote 
output_name to graph output if it now exists in\n            // value_info but not in graph.output.\n            if !graph_after\n                .output\n                .iter()\n                .any(|o| o.name == info.output_name)\n                && let Some(vi) = graph_after\n                    .value_info\n                    .iter()\n                    .find(|v| v.name == info.output_name)\n                    .cloned()\n            {\n                graph_after.output.push(vi);\n            }\n        }\n        onnx_proto::save_model(&model_after, &tmpl_path)?;\n    }\n\n    Ok(tmpl_path)\n}\n\npub fn create_elementwise_tile_slice(\n    model: &ModelProto,\n    segment_size: i64,\n    slice_idx: usize,\n    output_dir: &Path,\n) -> Result<TileSliceResult> {\n    if segment_size <= 0 {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"create_elementwise_tile_slice: segment_size must be > 0, got {segment_size}\"\n        )));\n    }\n    let graph = model.graph.as_ref().ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\n            \"create_elementwise_tile_slice: model.graph is None\".to_string(),\n        )\n    })?;\n    if graph.input.is_empty() {\n        return Err(crate::error::DsperseError::Slicer(\n            \"create_elementwise_tile_slice: no graph inputs\".to_string(),\n        ));\n    }\n    let out = graph.output.first().ok_or_else(|| {\n        crate::error::DsperseError::Slicer(\n            \"create_elementwise_tile_slice: no graph outputs\".to_string(),\n        )\n    })?;\n    let orig_output_name = &out.name;\n\n    let tile_shape: Vec<i64> = vec![segment_size];\n\n    let init_names: std::collections::HashSet<&str> =\n        graph.initializer.iter().map(|i| i.name.as_str()).collect();\n\n    let mut orig_to_tile: Vec<(String, String)> = Vec::with_capacity(graph.input.len());\n    let mut tile_inputs = Vec::with_capacity(graph.input.len());\n    let mut tile_idx = 0usize;\n    for inp in 
&graph.input {\n        let inp_shape = onnx_proto::shape_from_value_info(inp);\n        let is_broadcast = init_names.contains(inp.name.as_str())\n            || inp_shape\n                .as_ref()\n                .is_some_and(|s| s.iter().product::<i64>() < segment_size);\n        if is_broadcast {\n            tile_inputs.push(inp.clone());\n        } else {\n            let tile_name = format!(\"tile_in_{tile_idx}\");\n            tile_idx += 1;\n            tile_inputs.push(onnx_proto::make_tensor_value_info(\n                &tile_name,\n                onnx_proto::elem_type_from_value_info(inp).unwrap_or(TensorProto::FLOAT),\n                &tile_shape,\n            ));\n            orig_to_tile.push((inp.name.clone(), tile_name));\n        }\n    }\n    if tile_idx == 1\n        && let Some((_, tile_name)) = orig_to_tile.first_mut()\n    {\n        let old = tile_name.clone();\n        *tile_name = \"tile_in\".to_string();\n        for ti in &mut tile_inputs {\n            if ti.name == old {\n                ti.name = \"tile_in\".to_string();\n            }\n        }\n    }\n\n    let y = onnx_proto::make_tensor_value_info(\"tile_out\", TensorProto::FLOAT, &tile_shape);\n\n    let initializers: Vec<_> = graph.initializer.to_vec();\n\n    let input_remap: std::collections::HashMap<&str, &str> = orig_to_tile\n        .iter()\n        .map(|(k, v)| (k.as_str(), v.as_str()))\n        .collect();\n\n    let mut nodes = Vec::new();\n    for orig_node in &graph.node {\n        let new_inputs: Vec<String> = orig_node\n            .input\n            .iter()\n            .map(|name| {\n                input_remap\n                    .get(name.as_str())\n                    .map(|s| (*s).to_string())\n                    .unwrap_or_else(|| name.clone())\n            })\n            .collect();\n        let produces_output = orig_node.output.contains(orig_output_name);\n        let new_outputs = if produces_output {\n            orig_node\n                
.output\n                .iter()\n                .map(|o| {\n                    if o == orig_output_name {\n                        \"tile_out\".to_string()\n                    } else {\n                        o.clone()\n                    }\n                })\n                .collect()\n        } else {\n            orig_node.output.clone()\n        };\n\n        nodes.push(NodeProto {\n            op_type: orig_node.op_type.clone(),\n            input: new_inputs,\n            output: new_outputs,\n            attribute: orig_node.attribute.clone(),\n            name: String::new(),\n            domain: String::new(),\n            doc_string: String::new(),\n            overload: String::new(),\n            metadata_props: vec![],\n            device_configurations: vec![],\n        });\n    }\n\n    let tile_graph = onnx_proto::make_graph(\n        &format!(\"tile_{slice_idx}\"),\n        nodes,\n        tile_inputs,\n        vec![y],\n        initializers,\n    );\n    let tile_model = onnx_proto::make_model(tile_graph, model_opset(model));\n\n    let tiles_dir = output_dir.join(\"tiles\");\n    std::fs::create_dir_all(&tiles_dir)\n        .map_err(|e| crate::error::DsperseError::io(e, &tiles_dir))?;\n    let onnx_path = tiles_dir.join(\"tile.onnx\");\n    onnx_proto::save_model(&tile_model, &onnx_path)?;\n\n    Ok(TileSliceResult {\n        path: format!(\"slice_{slice_idx}/payload/tiles/tile.onnx\"),\n        conv_out: [segment_size, 1],\n    })\n}\n\n#[derive(Debug)]\npub struct TileSliceResult {\n    pub path: String,\n    pub conv_out: [i64; 2],\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn halo_symmetric_pads() {\n        assert_eq!(compute_halo_size([1, 1, 1, 1]), Some([1, 1, 1, 1]));\n    }\n\n    #[test]\n    fn halo_asymmetric_pads() {\n        assert_eq!(compute_halo_size([6, 6, 7, 7]), Some([6, 6, 7, 7]));\n    }\n\n    #[test]\n    fn halo_zero_pads() {\n        assert_eq!(compute_halo_size([0, 0, 0, 0]), Some([0, 
0, 0, 0]));\n    }\n\n    #[test]\n    fn halo_negative_pads_rejected() {\n        assert_eq!(compute_halo_size([-1, 0, 0, 0]), None);\n    }\n\n    #[test]\n    fn halo_mixed_pads() {\n        assert_eq!(compute_halo_size([1, 2, 1, 2]), Some([1, 2, 1, 2]));\n    }\n\n    #[test]\n    fn min_tile_3x3_no_dilation() {\n        assert_eq!(compute_min_spatial_tile([3, 3], [1, 1]), Some(4));\n    }\n\n    #[test]\n    fn min_tile_5x5_no_dilation() {\n        assert_eq!(compute_min_spatial_tile([5, 5], [1, 1]), Some(6));\n    }\n\n    #[test]\n    fn min_tile_3x3_dilation_2() {\n        let eff = (3 - 1) * 2 + 1;\n        assert_eq!(compute_min_spatial_tile([3, 3], [2, 2]), Some(eff + 1));\n    }\n\n    #[test]\n    fn min_tile_1x1() {\n        assert_eq!(compute_min_spatial_tile([1, 1], [1, 1]), Some(2));\n    }\n\n    #[test]\n    fn optimal_tile_exact_divisor() {\n        assert_eq!(find_optimal_tile_size(64, 32, 4, 1), Some(32));\n    }\n\n    #[test]\n    fn optimal_tile_no_exact_divisor_falls_back() {\n        assert_eq!(find_optimal_tile_size(64, 30, 4, 1), Some(16));\n    }\n\n    #[test]\n    fn optimal_tile_target_equals_spatial() {\n        assert_eq!(find_optimal_tile_size(32, 32, 4, 1), None);\n    }\n\n    #[test]\n    fn optimal_tile_min_exceeds_target() {\n        assert_eq!(find_optimal_tile_size(64, 3, 4, 1), None);\n    }\n\n    #[test]\n    fn optimal_tile_stride_constraint() {\n        assert_eq!(find_optimal_tile_size(64, 32, 4, 2), Some(32));\n        assert_eq!(find_optimal_tile_size(12, 8, 2, 4), Some(4));\n    }\n\n    #[test]\n    fn optimal_tile_no_valid_stride_divisor() {\n        assert_eq!(find_optimal_tile_size(15, 10, 2, 4), None);\n    }\n\n    #[test]\n    fn checked_dim_product_normal() {\n        assert_eq!(checked_dim_product(&[2, 3, 4]).unwrap(), 24);\n    }\n\n    #[test]\n    fn checked_dim_product_empty() {\n        assert_eq!(checked_dim_product(&[]).unwrap(), 1);\n    }\n\n    #[test]\n    fn checked_dim_product_overflow() {\n  
      assert!(checked_dim_product(&[usize::MAX, 2]).is_err());\n    }\n\n    #[test]\n    fn checked_dim_product_single() {\n        assert_eq!(checked_dim_product(&[42]).unwrap(), 42);\n    }\n\n    #[test]\n    fn slice_weights_basic() {\n        let weights = WeightInfo {\n            data: (0..24).map(|i| i as f32).collect(),\n            dims: vec![2, 3, 2, 2],\n        };\n        let sliced = slice_weights(&weights, 0, 2).unwrap();\n        assert_eq!(sliced.dims, vec![2, 2, 2, 2]);\n        assert_eq!(sliced.data.len(), 16);\n        assert_eq!(sliced.data[0], 0.0);\n        assert_eq!(sliced.data[1], 1.0);\n        assert_eq!(sliced.data[2], 2.0);\n        assert_eq!(sliced.data[3], 3.0);\n    }\n\n    #[test]\n    fn slice_weights_single_channel() {\n        let weights = WeightInfo {\n            data: (0..24).map(|i| i as f32).collect(),\n            dims: vec![2, 3, 2, 2],\n        };\n        let sliced = slice_weights(&weights, 1, 2).unwrap();\n        assert_eq!(sliced.dims, vec![2, 1, 2, 2]);\n        assert_eq!(sliced.data.len(), 8);\n    }\n\n    #[test]\n    fn slice_weights_start_ge_end() {\n        let weights = WeightInfo {\n            data: vec![1.0; 16],\n            dims: vec![1, 4, 2, 2],\n        };\n        assert!(slice_weights(&weights, 3, 2).is_err());\n    }\n\n    #[test]\n    fn slice_weights_end_exceeds_c_in() {\n        let weights = WeightInfo {\n            data: vec![1.0; 16],\n            dims: vec![1, 4, 2, 2],\n        };\n        assert!(slice_weights(&weights, 0, 5).is_err());\n    }\n\n    #[test]\n    fn slice_weights_insufficient_dims() {\n        let weights = WeightInfo {\n            data: vec![1.0; 6],\n            dims: vec![2, 3],\n        };\n        assert!(slice_weights(&weights, 0, 1).is_err());\n    }\n\n    #[test]\n    fn slice_weights_data_length_mismatch() {\n        let weights = WeightInfo {\n            data: vec![1.0; 10],\n            dims: vec![2, 3, 2, 2],\n        };\n        
assert!(slice_weights(&weights, 0, 2).is_err());\n    }\n\n    #[test]\n    fn elementwise_ops_recognized() {\n        assert!(is_elementwise(\"Relu\"));\n        assert!(is_elementwise(\"Sigmoid\"));\n        assert!(is_elementwise(\"Add\"));\n        assert!(is_elementwise(\"Mul\"));\n    }\n\n    #[test]\n    fn non_elementwise_ops_rejected() {\n        assert!(!is_elementwise(\"Conv\"));\n        assert!(!is_elementwise(\"MaxPool\"));\n        assert!(!is_elementwise(\"Gemm\"));\n        assert!(!is_elementwise(\"BatchNormalization\"));\n    }\n\n    #[test]\n    fn spatial_tile_config_already_fits() {\n        let (tile, reason) = calculate_spatial_tile_config(3, 4, 4, 64, 4, 1);\n        assert!(tile.is_none());\n        assert_eq!(reason, Some(\"already_fits\"));\n    }\n\n    #[test]\n    fn spatial_tile_config_min_tile_too_large() {\n        let (tile, reason) = calculate_spatial_tile_config(64, 8, 8, 100, 8, 1);\n        assert!(tile.is_none());\n        assert_eq!(reason, Some(\"min_tile_too_large\"));\n    }\n\n    #[test]\n    fn spatial_tile_config_finds_tile() {\n        let (tile, reason) = calculate_spatial_tile_config(3, 64, 64, 3 * 32 * 32, 4, 1);\n        assert!(tile.is_some());\n        assert!(reason.is_none());\n        let t = tile.unwrap();\n        assert!(64 % t == 0);\n        assert!(t >= 4);\n    }\n\n    #[test]\n    fn channel_split_config_basic() {\n        let result = calculate_channel_split_config(64, 32, 4, 4, 32);\n        assert!(result.is_some());\n        let (num_groups, cpg) = result.unwrap();\n        assert!(num_groups > 1);\n        assert!(cpg > 0);\n        assert!(cpg * (num_groups - 1) < 64);\n    }\n\n    #[test]\n    fn channel_split_config_zero_dims() {\n        assert!(calculate_channel_split_config(64, 32, 0, 4, 32).is_none());\n        assert!(calculate_channel_split_config(64, 32, 4, 0, 32).is_none());\n    }\n\n    #[test]\n    fn channel_split_config_fits_without_splitting() {\n        
assert!(calculate_channel_split_config(4, 32, 2, 2, 100).is_none());\n    }\n\n    #[test]\n    fn detect_tiling_none_without_tile_size() {\n        let model = onnx_proto::make_model(\n            onnx_proto::make_graph(\"test\", vec![], vec![], vec![], vec![]),\n            13,\n        );\n        assert!(detect_tiling_needs(&model, None).is_none());\n    }\n\n    #[test]\n    fn detect_tiling_none_empty_graph() {\n        let model = onnx_proto::make_model(\n            onnx_proto::make_graph(\"test\", vec![], vec![], vec![], vec![]),\n            13,\n        );\n        assert!(detect_tiling_needs(&model, Some(1024)).is_none());\n    }\n\n    #[test]\n    fn effective_kernel_overflow() {\n        assert_eq!(effective_kernel([i64::MAX, 1], [2, 1]), None);\n        assert_eq!(effective_kernel([1, i64::MAX], [1, 2]), None);\n    }\n\n    #[test]\n    fn effective_kernel_sub_underflow() {\n        assert_eq!(effective_kernel([i64::MIN, 3], [1, 1]), None);\n    }\n\n    #[test]\n    fn effective_kernel_valid() {\n        assert_eq!(effective_kernel([3, 3], [1, 1]), Some([3, 3]));\n        assert_eq!(effective_kernel([3, 3], [2, 2]), Some([5, 5]));\n        assert_eq!(effective_kernel([1, 1], [1, 1]), Some([1, 1]));\n    }\n\n    #[test]\n    fn conv_output_hw_zero_stride() {\n        assert_eq!(\n            conv_output_hw(8, 8, [0, 0, 0, 0], [3, 3], [1, 1], [0, 1]),\n            None\n        );\n        assert_eq!(\n            conv_output_hw(8, 8, [0, 0, 0, 0], [3, 3], [1, 1], [1, 0]),\n            None\n        );\n    }\n\n    #[test]\n    fn conv_output_hw_kernel_exceeds_input() {\n        assert_eq!(\n            conv_output_hw(2, 2, [0, 0, 0, 0], [5, 5], [1, 1], [1, 1]),\n            None\n        );\n    }\n\n    #[test]\n    fn conv_output_hw_overflow_pads() {\n        assert_eq!(\n            conv_output_hw(i64::MAX, 8, [1, 0, 0, 0], [3, 3], [1, 1], [1, 1]),\n            None\n        );\n    }\n\n    #[test]\n    fn conv_output_hw_valid() {\n        
assert_eq!(\n            conv_output_hw(8, 8, [1, 1, 1, 1], [3, 3], [1, 1], [1, 1]),\n            Some((8, 8))\n        );\n        assert_eq!(\n            conv_output_hw(8, 8, [0, 0, 0, 0], [3, 3], [1, 1], [2, 2]),\n            Some((3, 3))\n        );\n    }\n\n    #[test]\n    fn compute_halo_size_negative_rejected() {\n        assert_eq!(compute_halo_size([0, 0, -1, 0]), None);\n    }\n\n    #[test]\n    fn compute_min_spatial_tile_overflow() {\n        assert_eq!(compute_min_spatial_tile([i64::MAX, 1], [2, 1]), None);\n    }\n\n    #[test]\n    fn slice_weights_full_range_is_identity() {\n        let data: Vec<f32> = (0..48).map(|i| i as f32).collect();\n        let weights = WeightInfo {\n            data: data.clone(),\n            dims: vec![2, 3, 2, 4],\n        };\n        let sliced = slice_weights(&weights, 0, 3).unwrap();\n        assert_eq!(sliced.dims, vec![2, 3, 2, 4]);\n        assert_eq!(sliced.data, data);\n    }\n\n    #[test]\n    fn detect_dim_split_gemm_trans_b() {\n        use super::onnx_proto::{NodeProto, make_attribute_int};\n\n        // Unbiased Gemm with transB=1. 
Biased Gemm is rejected upstream by\n        // create_matmul_dim_template, so the detector now skips it as well.\n        let node = NodeProto {\n            op_type: \"Gemm\".to_string(),\n            input: vec![\"input\".to_string(), \"weight\".to_string()],\n            output: vec![\"output\".to_string()],\n            attribute: vec![make_attribute_int(\"transB\", 1)],\n            ..Default::default()\n        };\n\n        let mut shapes = HashMap::new();\n        shapes.insert(\"input\".to_string(), vec![4, 145, 384]);\n        shapes.insert(\"weight\".to_string(), vec![1536, 384]);\n        shapes.insert(\"output\".to_string(), vec![4, 145, 1536]);\n\n        let mut init_names = HashSet::new();\n        init_names.insert(\"weight\".to_string());\n\n        let detection = detect_dim_split(&[node], &shapes, &init_names, 17);\n        assert!(detection.is_some());\n        let d = detection.unwrap();\n        assert_eq!(d.split_dim, 0);\n        assert_eq!(d.dim_size, 580);\n        assert_eq!(d.num_groups, 580);\n        assert_eq!(d.elements_per_group, 1);\n        assert_eq!(d.k_dim, 384);\n        assert_eq!(d.n_dim, 1536);\n        assert!(matches!(d.split_kind, DimSplitKind::MatMulOutputDim));\n    }\n\n    #[test]\n    fn detect_dim_split_matmul_no_trans() {\n        let node = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            input: vec![\"input\".to_string(), \"weight\".to_string()],\n            output: vec![\"output\".to_string()],\n            ..Default::default()\n        };\n\n        let mut shapes = HashMap::new();\n        shapes.insert(\"input\".to_string(), vec![4, 145, 384]);\n        shapes.insert(\"weight\".to_string(), vec![384, 1536]);\n        shapes.insert(\"output\".to_string(), vec![4, 145, 1536]);\n\n        let mut init_names = HashSet::new();\n        init_names.insert(\"weight\".to_string());\n\n        let detection = detect_dim_split(&[node], &shapes, &init_names, 17);\n        
assert!(detection.is_some());\n        let d = detection.unwrap();\n        assert_eq!(d.split_dim, 0);\n        assert_eq!(d.dim_size, 580);\n        assert_eq!(d.num_groups, 580);\n        assert_eq!(d.elements_per_group, 1);\n        assert_eq!(d.k_dim, 384);\n        assert_eq!(d.n_dim, 1536);\n        assert!(matches!(d.split_kind, DimSplitKind::MatMulOutputDim));\n    }\n\n    #[test]\n    fn detect_dim_split_k_chunks_saturate_budget() {\n        // k_dim=10, n_dim=300_000: row_cost=6M. Naive k_chunks=ceil(6M/2M)=3\n        // yields chunk_size=ceil(10/3)=4 -> per-chunk=4*300_000*2=2.4M > 2M\n        // (MAX_ESTIMATED_CONSTRAINTS). Loop bumps k_chunks to 4 giving\n        // chunk_size=3 -> per-chunk=1.8M which fits.\n        let node = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            input: vec![\"input\".to_string(), \"weight\".to_string()],\n            output: vec![\"output\".to_string()],\n            ..Default::default()\n        };\n        let mut shapes = HashMap::new();\n        shapes.insert(\"input\".to_string(), vec![4, 10]);\n        shapes.insert(\"weight\".to_string(), vec![10, 300_000]);\n        shapes.insert(\"output\".to_string(), vec![4, 300_000]);\n        let mut init_names = HashSet::new();\n        init_names.insert(\"weight\".to_string());\n\n        let d = detect_dim_split(&[node], &shapes, &init_names, 17).unwrap();\n        assert_eq!(d.k_dim, 10);\n        assert_eq!(d.n_dim, 300_000);\n        let chunk_size = d.k_dim.div_ceil(d.k_chunks);\n        assert!(\n            chunk_size * d.n_dim * 2 <= MAX_ESTIMATED_CONSTRAINTS as usize,\n            \"per-chunk cost {} exceeds MAX {}\",\n            chunk_size * d.n_dim * 2,\n            MAX_ESTIMATED_CONSTRAINTS\n        );\n    }\n\n    #[test]\n    fn detect_dim_split_single_row_with_k_chunking() {\n        // total_rows=1 but k*n*2 > MAX: still detect, K-chunk it.\n        let node = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            
input: vec![\"input\".to_string(), \"weight\".to_string()],\n            output: vec![\"output\".to_string()],\n            ..Default::default()\n        };\n        let mut shapes = HashMap::new();\n        shapes.insert(\"input\".to_string(), vec![1, 2048]);\n        shapes.insert(\"weight\".to_string(), vec![2048, 2048]);\n        shapes.insert(\"output\".to_string(), vec![1, 2048]);\n        let mut init_names = HashSet::new();\n        init_names.insert(\"weight\".to_string());\n\n        let d = detect_dim_split(&[node], &shapes, &init_names, 17).unwrap();\n        assert_eq!(d.dim_size, 1);\n        assert_eq!(d.num_groups, 1);\n        assert!(d.k_chunks > 1, \"expected K-chunking for single row\");\n        let chunk_size = d.k_dim.div_ceil(d.k_chunks);\n        assert!(chunk_size * d.n_dim * 2 <= MAX_ESTIMATED_CONSTRAINTS as usize);\n    }\n\n    #[test]\n    fn detect_dim_split_skips_single_row_single_chunk() {\n        // total_rows=1 and k*n*2 <= MAX: nothing to split via MatMul path.\n        // Even with a second MatMul chained in, the slice stays within\n        // budget, so dim-split should decline and let the caller fall through.\n        let node1 = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            input: vec![\"input\".to_string(), \"w1\".to_string()],\n            output: vec![\"mid\".to_string()],\n            ..Default::default()\n        };\n        let node2 = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            input: vec![\"mid\".to_string(), \"w2\".to_string()],\n            output: vec![\"output\".to_string()],\n            ..Default::default()\n        };\n        let mut shapes = HashMap::new();\n        shapes.insert(\"input\".to_string(), vec![1, 64]);\n        shapes.insert(\"w1\".to_string(), vec![64, 64]);\n        shapes.insert(\"mid\".to_string(), vec![1, 64]);\n        shapes.insert(\"w2\".to_string(), vec![64, 64]);\n        shapes.insert(\"output\".to_string(), vec![1, 64]);\n        let mut 
init_names = HashSet::new();\n        init_names.insert(\"w1\".to_string());\n        init_names.insert(\"w2\".to_string());\n        // Tiny per-op cost; slice estimate stays under MAX so detect_dim_split\n        // returns None at the outer gate, which is what we want for a\n        // single-row single-chunk MatMul.\n        assert!(detect_dim_split(&[node1, node2], &shapes, &init_names, 17).is_none());\n    }\n\n    #[test]\n    fn detect_dim_split_declines_infeasible_n() {\n        // n_dim * 2 > MAX means even k_chunks == k_dim (chunk_size = 1)\n        // cannot fit inside the per-chunk budget, so the MatMul branch must\n        // decline. Use batch=1 so the BatchDim fallback path is not taken.\n        let node = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            input: vec![\"input\".to_string(), \"weight\".to_string()],\n            output: vec![\"output\".to_string()],\n            ..Default::default()\n        };\n        let mut shapes = HashMap::new();\n        // n_dim = 1_500_000 -> n*2 = 3_000_000 > MAX (2_000_000)\n        shapes.insert(\"input\".to_string(), vec![1, 4]);\n        shapes.insert(\"weight\".to_string(), vec![4, 1_500_000]);\n        shapes.insert(\"output\".to_string(), vec![1, 1_500_000]);\n        let mut init_names = HashSet::new();\n        init_names.insert(\"weight\".to_string());\n\n        let got = detect_dim_split(&[node], &shapes, &init_names, 17);\n        assert!(\n            got.as_ref()\n                .is_none_or(|d| !matches!(d.split_kind, DimSplitKind::MatMulOutputDim)),\n            \"expected MatMul dim-split to decline, got {got:?}\"\n        );\n    }\n\n    #[test]\n    fn detect_dim_split_skips_non_terminal_matmul() {\n        // MatMul output is consumed by a later Add inside the same slice.\n        // The dim-split runner only writes MatMul output to the cache, so\n        // the Add would never run; detection must decline this MatMul and\n        // either pick a later terminal 
MatMul or fall through.\n        let matmul = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            input: vec![\"input\".to_string(), \"weight\".to_string()],\n            output: vec![\"mid\".to_string()],\n            ..Default::default()\n        };\n        let add = NodeProto {\n            op_type: \"Add\".to_string(),\n            input: vec![\"mid\".to_string(), \"bias\".to_string()],\n            output: vec![\"output\".to_string()],\n            ..Default::default()\n        };\n        let mut shapes = HashMap::new();\n        shapes.insert(\"input\".to_string(), vec![1, 145, 384]);\n        shapes.insert(\"weight\".to_string(), vec![384, 1536]);\n        shapes.insert(\"bias\".to_string(), vec![1536]);\n        shapes.insert(\"mid\".to_string(), vec![1, 145, 1536]);\n        shapes.insert(\"output\".to_string(), vec![1, 145, 1536]);\n        let mut init_names = HashSet::new();\n        init_names.insert(\"weight\".to_string());\n        init_names.insert(\"bias\".to_string());\n\n        let got = detect_dim_split(&[matmul, add], &shapes, &init_names, 17);\n        assert!(\n            got.as_ref()\n                .is_none_or(|d| !matches!(d.split_kind, DimSplitKind::MatMulOutputDim)),\n            \"expected non-terminal MatMul to be declined, got {got:?}\"\n        );\n    }\n\n    #[test]\n    fn detect_dim_split_picks_terminal_matmul_after_consumed_one() {\n        // First MatMul feeds a second MatMul; only the second is terminal,\n        // so detection must skip the first and select the second when both\n        // are otherwise eligible.\n        let m1 = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            input: vec![\"input\".to_string(), \"w1\".to_string()],\n            output: vec![\"mid\".to_string()],\n            ..Default::default()\n        };\n        let m2 = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            input: vec![\"mid\".to_string(), \"w2\".to_string()],\n           
 output: vec![\"output\".to_string()],\n            ..Default::default()\n        };\n        let mut shapes = HashMap::new();\n        shapes.insert(\"input\".to_string(), vec![4, 145, 384]);\n        shapes.insert(\"w1\".to_string(), vec![384, 1536]);\n        shapes.insert(\"mid\".to_string(), vec![4, 145, 1536]);\n        shapes.insert(\"w2\".to_string(), vec![1536, 384]);\n        shapes.insert(\"output\".to_string(), vec![4, 145, 384]);\n        let mut init_names = HashSet::new();\n        init_names.insert(\"w1\".to_string());\n        init_names.insert(\"w2\".to_string());\n\n        let d = detect_dim_split(&[m1, m2], &shapes, &init_names, 17).unwrap();\n        assert_eq!(d.weight_name.as_deref(), Some(\"w2\"));\n        assert_eq!(d.output_name, \"output\");\n        assert_eq!(d.k_dim, 1536);\n        assert_eq!(d.n_dim, 384);\n    }\n\n    #[test]\n    fn detect_dim_split_skips_gemm_trans_a() {\n        use super::onnx_proto::make_attribute_int;\n        let node = NodeProto {\n            op_type: \"Gemm\".to_string(),\n            input: vec![\"input\".to_string(), \"weight\".to_string()],\n            output: vec![\"output\".to_string()],\n            attribute: vec![make_attribute_int(\"transA\", 1)],\n            ..Default::default()\n        };\n        let mut shapes = HashMap::new();\n        // Use batch=1 so the BatchDim fallback path does not mask the\n        // MatMul-branch decline we want to assert.\n        shapes.insert(\"input\".to_string(), vec![1, 384, 145]);\n        shapes.insert(\"weight\".to_string(), vec![384, 1536]);\n        shapes.insert(\"output\".to_string(), vec![1, 145, 1536]);\n        let mut init_names = HashSet::new();\n        init_names.insert(\"weight\".to_string());\n\n        let got = detect_dim_split(&[node], &shapes, &init_names, 17);\n        assert!(\n            got.as_ref()\n                .is_none_or(|d| !matches!(d.split_kind, DimSplitKind::MatMulOutputDim)),\n            \"expected Gemm transA=1 
MatMul decline, got {got:?}\"\n        );\n    }\n\n    #[test]\n    fn detect_dim_split_skips_gemm_with_bias() {\n        use super::onnx_proto::make_attribute_int;\n        let node = NodeProto {\n            op_type: \"Gemm\".to_string(),\n            input: vec![\n                \"input\".to_string(),\n                \"weight\".to_string(),\n                \"bias\".to_string(),\n            ],\n            output: vec![\"output\".to_string()],\n            attribute: vec![make_attribute_int(\"transB\", 1)],\n            ..Default::default()\n        };\n        let mut shapes = HashMap::new();\n        // Use batch=1 so the BatchDim fallback path does not mask the\n        // MatMul-branch decline we want to assert.\n        shapes.insert(\"input\".to_string(), vec![1, 145, 384]);\n        shapes.insert(\"weight\".to_string(), vec![1536, 384]);\n        shapes.insert(\"bias\".to_string(), vec![1536]);\n        shapes.insert(\"output\".to_string(), vec![1, 145, 1536]);\n        let mut init_names = HashSet::new();\n        init_names.insert(\"weight\".to_string());\n        init_names.insert(\"bias\".to_string());\n\n        // Detector should decline the MatMul branch since the template\n        // builder cannot handle biased Gemm, forcing fall-through.\n        let got = detect_dim_split(&[node], &shapes, &init_names, 17);\n        assert!(\n            got.as_ref()\n                .is_none_or(|d| !matches!(d.split_kind, DimSplitKind::MatMulOutputDim)),\n            \"expected Gemm-with-bias MatMul decline, got {got:?}\"\n        );\n    }\n\n    #[test]\n    fn create_matmul_dim_template_uses_info_weight_name() {\n        // Graph has two MatMul nodes referencing different weights. 
The\n        // template builder must pick the node whose input is info.weight_name,\n        // not the first MatMul encountered.\n        let x = onnx_proto::make_tensor_value_info(\"input\", TensorProto::FLOAT, &[4, 64]);\n        let y = onnx_proto::make_tensor_value_info(\"output\", TensorProto::FLOAT, &[4, 2048]);\n\n        let w_small = onnx_proto::make_tensor(\n            \"w_small\",\n            TensorProto::FLOAT,\n            &[64, 64],\n            vec![0.0f32; 64 * 64],\n        );\n        let w_big = onnx_proto::make_tensor(\n            \"w_big\",\n            TensorProto::FLOAT,\n            &[64, 2048],\n            vec![0.0f32; 64 * 2048],\n        );\n\n        let n1 = onnx_proto::make_node(\n            \"MatMul\",\n            vec![\"input\".into(), \"w_small\".into()],\n            vec![\"mid\".into()],\n            vec![],\n        );\n        let n2 = onnx_proto::make_node(\n            \"MatMul\",\n            vec![\"mid\".into(), \"w_big\".into()],\n            vec![\"output\".into()],\n            vec![],\n        );\n\n        let graph = onnx_proto::make_graph(\n            \"two_matmul\",\n            vec![n1, n2],\n            vec![x],\n            vec![y],\n            vec![w_small, w_big],\n        );\n        let model = onnx_proto::make_model(graph, 13);\n\n        let info = crate::schema::tiling::DimSplitInfo {\n            slice_idx: 0,\n            weight_name: Some(\"w_big\".to_string()),\n            input_name: \"mid\".to_string(),\n            output_name: \"output\".to_string(),\n            k_dim: 64,\n            n_dim: 2048,\n            k_chunks: 1,\n            ..Default::default()\n        };\n\n        let tmp = tempfile::tempdir().unwrap();\n        let tmpl_path = create_dim_split_template(&model, &info, tmp.path(), None).unwrap();\n        let tmpl_model = onnx_proto::load_model(&tmpl_path).unwrap();\n        let g = tmpl_model.graph.as_ref().unwrap();\n        let w = g.initializer.iter().find(|i| i.name 
== \"W\").unwrap();\n        // Template weight shape must reflect w_big (64, 2048), not w_small.\n        assert_eq!(w.dims, vec![64, 2048]);\n    }\n\n    #[test]\n    fn create_matmul_dim_template_disambiguates_shared_weight() {\n        // Two MatMul ops share the same weight initializer (e.g. tied\n        // weights). The template builder must select the op whose\n        // input/output names match info, not the first node that happens\n        // to reference the initializer.\n        let x = onnx_proto::make_tensor_value_info(\"input\", TensorProto::FLOAT, &[4, 64]);\n        let y_a = onnx_proto::make_tensor_value_info(\"out_a\", TensorProto::FLOAT, &[4, 32]);\n        let y_b = onnx_proto::make_tensor_value_info(\"out_b\", TensorProto::FLOAT, &[1, 32]);\n\n        let shared_w = onnx_proto::make_tensor(\n            \"tied_w\",\n            TensorProto::FLOAT,\n            &[64, 32],\n            vec![0.0f32; 64 * 32],\n        );\n\n        // First op: input -> tied_w -> out_a (shape [4, 32])\n        let n_a = onnx_proto::make_node(\n            \"MatMul\",\n            vec![\"input\".into(), \"tied_w\".into()],\n            vec![\"out_a\".into()],\n            vec![],\n        );\n        // Second op: alt_in -> tied_w -> out_b (shape [1, 32])\n        let alt_in = onnx_proto::make_tensor_value_info(\"alt_in\", TensorProto::FLOAT, &[1, 64]);\n        let n_b = onnx_proto::make_node(\n            \"MatMul\",\n            vec![\"alt_in\".into(), \"tied_w\".into()],\n            vec![\"out_b\".into()],\n            vec![],\n        );\n\n        let graph = onnx_proto::make_graph(\n            \"shared_weight\",\n            vec![n_a, n_b],\n            vec![x, alt_in],\n            vec![y_a, y_b],\n            vec![shared_w],\n        );\n        let model = onnx_proto::make_model(graph, 13);\n\n        // Target the second op explicitly via input_name/output_name.\n        let info = crate::schema::tiling::DimSplitInfo {\n            slice_idx: 0,\n   
         weight_name: Some(\"tied_w\".to_string()),\n            input_name: \"alt_in\".to_string(),\n            output_name: \"out_b\".to_string(),\n            k_dim: 64,\n            n_dim: 32,\n            k_chunks: 1,\n            ..Default::default()\n        };\n\n        let tmp = tempfile::tempdir().unwrap();\n        // Builder should succeed by binding the second op (the one whose\n        // IO matches info), even though the first op also references the\n        // same weight initializer.\n        let tmpl_path = create_dim_split_template(&model, &info, tmp.path(), None).unwrap();\n        let tmpl_model = onnx_proto::load_model(&tmpl_path).unwrap();\n        let g = tmpl_model.graph.as_ref().unwrap();\n        let w = g.initializer.iter().find(|i| i.name == \"W\").unwrap();\n        assert_eq!(w.dims, vec![64, 32]);\n    }\n\n    fn make_maxpool_node(\n        kernel: i64,\n        stride: i64,\n        pads: [i64; 4],\n        ceil_mode: Option<i64>,\n    ) -> NodeProto {\n        let mut attrs = vec![\n            onnx_proto::make_attribute_ints(\"kernel_shape\", &[kernel, kernel]),\n            onnx_proto::make_attribute_ints(\"strides\", &[stride, stride]),\n            onnx_proto::make_attribute_ints(\"pads\", &pads),\n        ];\n        if let Some(cm) = ceil_mode {\n            attrs.push(onnx_proto::make_attribute_int(\"ceil_mode\", cm));\n        }\n        onnx_proto::make_node(\n            \"MaxPool\",\n            vec![\"input\".into()],\n            vec![\"output\".into()],\n            attrs,\n        )\n    }\n\n    #[test]\n    fn pool_params_valid() {\n        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], None);\n        let pp = PoolParams::from_node(&node, 0);\n        assert!(pp.is_some());\n        let pp = pp.unwrap();\n        assert_eq!(pp.kernel, [2, 2]);\n        assert_eq!(pp.stride, [2, 2]);\n    }\n\n    #[test]\n    fn pool_params_rejects_ceil_mode() {\n        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], 
Some(1));\n        assert!(PoolParams::from_node(&node, 0).is_none());\n    }\n\n    #[test]\n    fn pool_params_accepts_ceil_mode_zero() {\n        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], Some(0));\n        assert!(PoolParams::from_node(&node, 0).is_some());\n    }\n\n    #[test]\n    fn pool_params_rejects_auto_pad() {\n        let mut attrs = vec![\n            onnx_proto::make_attribute_ints(\"kernel_shape\", &[2, 2]),\n            onnx_proto::make_attribute_ints(\"strides\", &[2, 2]),\n        ];\n        attrs.push(onnx_proto::AttributeProto {\n            name: \"auto_pad\".into(),\n            s: b\"SAME_UPPER\".to_vec(),\n            ..Default::default()\n        });\n        let node = onnx_proto::make_node(\n            \"MaxPool\",\n            vec![\"input\".into()],\n            vec![\"output\".into()],\n            attrs,\n        );\n        assert!(PoolParams::from_node(&node, 0).is_none());\n    }\n\n    #[test]\n    fn pool_params_rejects_non_maxpool() {\n        let node = onnx_proto::make_node(\n            \"Conv\",\n            vec![\"input\".into()],\n            vec![\"output\".into()],\n            vec![onnx_proto::make_attribute_ints(\"kernel_shape\", &[3, 3])],\n        );\n        assert!(PoolParams::from_node(&node, 0).is_none());\n    }\n\n    fn make_elementwise_model(op: &str, shape: &[i64]) -> ModelProto {\n        let x = onnx_proto::make_tensor_value_info(\"input\", TensorProto::FLOAT, shape);\n        let y = onnx_proto::make_tensor_value_info(\"output\", TensorProto::FLOAT, shape);\n        let node = onnx_proto::make_node(op, vec![\"input\".into()], vec![\"output\".into()], vec![]);\n        let graph = onnx_proto::make_graph(\"test\", vec![node], vec![x], vec![y], vec![]);\n        onnx_proto::make_model(graph, 13)\n    }\n\n    #[test]\n    fn fixed_segments_too_small_returns_none() {\n        let model = make_elementwise_model(\"Relu\", &[1, 3, 8, 8]);\n        
assert!(detect_elementwise_fixed_segments(model.graph.as_ref().unwrap()).is_none());\n    }\n\n    #[test]\n    fn fixed_segments_detects_large_tensor() {\n        let model = make_elementwise_model(\"Relu\", &[1, 16, 64, 64]);\n        let graph = model.graph.as_ref().unwrap();\n        let det = detect_elementwise_fixed_segments(graph);\n        assert!(det.is_some());\n        if let Some(TilingDetection::FixedSegment {\n            segment_size,\n            total_elements,\n            num_segments,\n            ..\n        }) = det\n        {\n            assert_eq!(total_elements, 16 * 64 * 64);\n            assert_eq!(segment_size, ELEMENTWISE_SEGMENT_SIZE);\n            assert_eq!(\n                num_segments,\n                (total_elements + segment_size - 1) / segment_size\n            );\n        } else {\n            panic!(\"expected FixedSegment variant\");\n        }\n    }\n\n    #[test]\n    fn fixed_segments_rejects_zero_dim() {\n        let model = make_elementwise_model(\"Relu\", &[1, 0, 64, 64]);\n        assert!(detect_elementwise_fixed_segments(model.graph.as_ref().unwrap()).is_none());\n    }\n\n    #[test]\n    fn fixed_segments_rejects_non_elementwise() {\n        let x = onnx_proto::make_tensor_value_info(\"input\", TensorProto::FLOAT, &[1, 16, 64, 64]);\n        let y = onnx_proto::make_tensor_value_info(\"output\", TensorProto::FLOAT, &[1, 16, 64, 64]);\n        let node = onnx_proto::make_node(\n            \"Softmax\",\n            vec![\"input\".into()],\n            vec![\"output\".into()],\n            vec![],\n        );\n        let graph = onnx_proto::make_graph(\"test\", vec![node], vec![x], vec![y], vec![]);\n        let model = onnx_proto::make_model(graph, 13);\n        assert!(detect_elementwise_fixed_segments(model.graph.as_ref().unwrap()).is_none());\n    }\n\n    #[test]\n    fn create_pool_tile_slice_valid() {\n        let x = onnx_proto::make_tensor_value_info(\"input\", TensorProto::FLOAT, &[1, 3, 64, 64]);\n     
   let y = onnx_proto::make_tensor_value_info(\"output\", TensorProto::FLOAT, &[1, 3, 32, 32]);\n        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], None);\n        let graph = onnx_proto::make_graph(\"pool\", vec![node], vec![x], vec![y], vec![]);\n        let model = onnx_proto::make_model(graph, 13);\n        let tmp = tempfile::tempdir().unwrap();\n        let result = create_pool_tile_slice(&model, 16, 0, tmp.path());\n        assert!(result.is_ok());\n        let r = result.unwrap();\n        assert!(r.path.contains(\"tile.onnx\"));\n    }\n\n    #[test]\n    fn create_pool_tile_slice_rejects_zero_tile() {\n        let x = onnx_proto::make_tensor_value_info(\"input\", TensorProto::FLOAT, &[1, 3, 64, 64]);\n        let y = onnx_proto::make_tensor_value_info(\"output\", TensorProto::FLOAT, &[1, 3, 32, 32]);\n        let node = make_maxpool_node(2, 2, [0, 0, 0, 0], None);\n        let graph = onnx_proto::make_graph(\"pool\", vec![node], vec![x], vec![y], vec![]);\n        let model = onnx_proto::make_model(graph, 13);\n        let tmp = tempfile::tempdir().unwrap();\n        assert!(create_pool_tile_slice(&model, 0, 0, tmp.path()).is_err());\n    }\n\n    #[test]\n    fn create_pool_tile_slice_no_pool_node() {\n        let x = onnx_proto::make_tensor_value_info(\"input\", TensorProto::FLOAT, &[1, 3, 64, 64]);\n        let y = onnx_proto::make_tensor_value_info(\"output\", TensorProto::FLOAT, &[1, 3, 64, 64]);\n        let node =\n            onnx_proto::make_node(\"Relu\", vec![\"input\".into()], vec![\"output\".into()], vec![]);\n        let graph = onnx_proto::make_graph(\"no_pool\", vec![node], vec![x], vec![y], vec![]);\n        let model = onnx_proto::make_model(graph, 13);\n        let tmp = tempfile::tempdir().unwrap();\n        assert!(create_pool_tile_slice(&model, 16, 0, tmp.path()).is_err());\n    }\n\n    #[test]\n    fn estimate_slice_constraints_clamps_symbolic_dimensions() {\n        // ONNX serializes dynamic axes as -1 and placeholder axes 
as 0.\n        // Both must be clamped to 1 before forwarding to the jstprove\n        // estimator, otherwise product(shape) multiplies by zero and\n        // collapses the op's cost contribution to 0.\n        let node = NodeProto {\n            op_type: \"MatMul\".to_string(),\n            input: vec![\"input\".to_string(), \"weight\".to_string()],\n            output: vec![\"output\".to_string()],\n            ..Default::default()\n        };\n\n        let mut symbolic_shapes = HashMap::new();\n        symbolic_shapes.insert(\"input\".to_string(), vec![-1, 64]);\n        symbolic_shapes.insert(\"weight\".to_string(), vec![64, 128]);\n        symbolic_shapes.insert(\"output\".to_string(), vec![0, 128]);\n\n        let mut concrete_shapes = HashMap::new();\n        concrete_shapes.insert(\"input\".to_string(), vec![1, 64]);\n        concrete_shapes.insert(\"weight\".to_string(), vec![64, 128]);\n        concrete_shapes.insert(\"output\".to_string(), vec![1, 128]);\n\n        let nodes = [node];\n        let symbolic_cost = estimate_slice_constraints(&nodes, &symbolic_shapes);\n        let concrete_cost = estimate_slice_constraints(&nodes, &concrete_shapes);\n\n        assert!(\n            symbolic_cost > 0,\n            \"symbolic dims must not collapse cost to zero\"\n        );\n        assert_eq!(\n            symbolic_cost, concrete_cost,\n            \"batch -1 and batch 0 must clamp to 1 and match concrete batch 1\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/combiner.rs",
    "content": "use std::collections::{HashMap, HashSet};\nuse std::path::{Path, PathBuf};\n\nuse super::onnx_proto::{self, ModelProto, TensorProto, ValueInfoProto};\nuse crate::error::{DsperseError, Result};\nuse crate::schema::metadata::ModelMetadata;\n\npub fn materialize_combined_model(\n    model: &ModelProto,\n    metadata: &ModelMetadata,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n    traced_types: Option<&HashMap<String, i32>>,\n) -> Result<ModelProto> {\n    let mut combined = model.clone();\n    let graph = combined\n        .graph\n        .as_mut()\n        .ok_or_else(|| DsperseError::Slicer(\"model.graph is None\".into()))?;\n\n    let existing_outputs: HashSet<String> = graph.output.iter().map(|o| o.name.clone()).collect();\n\n    let all_node_outputs: HashSet<String> = graph\n        .node\n        .iter()\n        .flat_map(|n| n.output.iter().cloned())\n        .collect();\n\n    let mut new_outputs: Vec<ValueInfoProto> = Vec::new();\n    let mut added: HashSet<String> = HashSet::new();\n\n    {\n        let vi_map = onnx_proto::build_value_info_map(graph);\n\n        for slice in &metadata.slices {\n            for output_name in &slice.dependencies.output {\n                if existing_outputs.contains(output_name) || added.contains(output_name) {\n                    continue;\n                }\n                if !all_node_outputs.contains(output_name) {\n                    tracing::warn!(\n                        tensor = %output_name,\n                        slice = slice.index,\n                        \"slice output not produced by any node in original graph, skipping\"\n                    );\n                    continue;\n                }\n\n                if let Some(vi) =\n                    resolve_value_info(output_name, &vi_map, traced_shapes, traced_types)?\n                {\n                    new_outputs.push(vi);\n                    added.insert(output_name.clone());\n                }\n            }\n\n          
  for input_name in &slice.dependencies.filtered_inputs {\n                if existing_outputs.contains(input_name) || added.contains(input_name) {\n                    continue;\n                }\n                if !all_node_outputs.contains(input_name) {\n                    tracing::debug!(\n                        tensor = %input_name,\n                        slice = slice.index,\n                        \"slice filtered_input not produced by any node in original graph, skipping\"\n                    );\n                    continue;\n                }\n\n                if let Some(vi) =\n                    resolve_value_info(input_name, &vi_map, traced_shapes, traced_types)?\n                {\n                    new_outputs.push(vi);\n                    added.insert(input_name.clone());\n                }\n            }\n        }\n    }\n\n    graph.output.extend(new_outputs);\n\n    tracing::info!(\n        intermediate_outputs = added.len(),\n        total_outputs = graph.output.len(),\n        \"combined model with slice boundary outputs\"\n    );\n\n    Ok(combined)\n}\n\nconst ONNX_STRING_DATATYPE: i32 = 8;\nconst NON_NUMERIC_TENSOR_TYPES: &[i32] = &[ONNX_STRING_DATATYPE];\n\nfn resolve_value_info(\n    name: &str,\n    vi_map: &HashMap<String, &ValueInfoProto>,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n    traced_types: Option<&HashMap<String, i32>>,\n) -> Result<Option<ValueInfoProto>> {\n    if let Some(vi) = vi_map.get(name) {\n        let elem_type = onnx_proto::elem_type_from_value_info(vi).unwrap_or(TensorProto::FLOAT);\n        if NON_NUMERIC_TENSOR_TYPES.contains(&elem_type) {\n            return Ok(None);\n        }\n        return Ok(Some((*vi).clone()));\n    }\n\n    let shape = traced_shapes.get(name).ok_or_else(|| {\n        DsperseError::Slicer(format!(\n            \"no shape info for combined model output tensor '{name}'\"\n        ))\n    })?;\n\n    let elem_type = traced_types\n        .and_then(|t| 
t.get(name).copied())\n        .unwrap_or(TensorProto::FLOAT);\n\n    if NON_NUMERIC_TENSOR_TYPES.contains(&elem_type) {\n        return Ok(None);\n    }\n\n    Ok(Some(onnx_proto::make_tensor_value_info(\n        name, elem_type, shape,\n    )))\n}\n\npub fn ensure_combined_materialized(\n    slices_dir: &Path,\n    metadata: &ModelMetadata,\n) -> Result<PathBuf> {\n    let output_path = slices_dir.join(\"combined.onnx\");\n    if output_path.exists() {\n        return Ok(output_path);\n    }\n    materialize_combined_to_disk(slices_dir, metadata)\n}\n\npub fn materialize_combined_to_disk(\n    slices_dir: &Path,\n    metadata: &ModelMetadata,\n) -> Result<PathBuf> {\n    let traced_shapes = metadata.traced_shapes.as_ref().ok_or_else(|| {\n        DsperseError::Slicer(\"metadata missing traced_shapes for combined model\".into())\n    })?;\n    let traced_types = metadata.traced_types.as_ref();\n    let original_path = metadata.original_model_path.as_ref().ok_or_else(|| {\n        DsperseError::Slicer(\"metadata missing original_model_path for combined model\".into())\n    })?;\n\n    let model_path = if Path::new(original_path).is_absolute() {\n        std::path::PathBuf::from(original_path)\n    } else {\n        slices_dir.join(original_path)\n    };\n\n    let mut model = onnx_proto::load_model(&model_path)?;\n    onnx_proto::normalize_opset(&mut model);\n\n    let combined = materialize_combined_model(&model, metadata, traced_shapes, traced_types)?;\n\n    let dest = slices_dir.join(\"combined.onnx\");\n    onnx_proto::save_model(&combined, &dest)?;\n    tracing::info!(path = %dest.display(), \"materialized combined ONNX\");\n\n    Ok(dest)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::schema::metadata::{\n        Dependencies, ModelMetadata, SliceMetadata, SliceShapeWrapper, TensorShape,\n    };\n\n    fn make_test_model(\n        node_output_types: HashMap<String, i32>,\n        traced_shapes: HashMap<String, Vec<i64>>,\n    ) -> 
(ModelProto, ModelMetadata) {\n        let graph = onnx_proto::GraphProto {\n            node: vec![onnx_proto::NodeProto {\n                op_type: \"Identity\".to_string(),\n                input: vec![\"input\".to_string()],\n                output: vec![\n                    \"float_tensor\".to_string(),\n                    \"bool_tensor\".to_string(),\n                    \"string_tensor\".to_string(),\n                    \"int_tensor\".to_string(),\n                ],\n                ..Default::default()\n            }],\n            input: vec![onnx_proto::make_tensor_value_info(\n                \"input\",\n                TensorProto::FLOAT,\n                &[1, 3, 8, 8],\n            )],\n            output: vec![onnx_proto::make_tensor_value_info(\n                \"model_output\",\n                TensorProto::FLOAT,\n                &[1, 3, 8, 8],\n            )],\n            ..Default::default()\n        };\n        let model = onnx_proto::make_model(graph, 13);\n\n        let metadata = ModelMetadata {\n            slices: vec![SliceMetadata {\n                index: 0,\n                filename: \"s0.onnx\".to_string(),\n                path: \"s0.onnx\".to_string(),\n                relative_path: \"s0.onnx\".to_string(),\n                shape: SliceShapeWrapper {\n                    tensor_shape: TensorShape {\n                        input: vec![],\n                        output: vec![],\n                    },\n                },\n                dependencies: Dependencies {\n                    input: vec![],\n                    filtered_inputs: vec![],\n                    output: vec![\n                        \"float_tensor\".to_string(),\n                        \"bool_tensor\".to_string(),\n                        \"string_tensor\".to_string(),\n                        \"int_tensor\".to_string(),\n                    ],\n                },\n                ..Default::default()\n            }],\n            traced_shapes: 
Some(traced_shapes.clone()),\n            traced_types: Some(node_output_types),\n            ..Default::default()\n        };\n\n        (model, metadata)\n    }\n\n    #[test]\n    fn bool_outputs_included_in_combined_model() {\n        let mut node_output_types = HashMap::new();\n        node_output_types.insert(\"float_tensor\".to_string(), TensorProto::FLOAT);\n        node_output_types.insert(\"bool_tensor\".to_string(), TensorProto::BOOL);\n        node_output_types.insert(\"string_tensor\".to_string(), ONNX_STRING_DATATYPE);\n        node_output_types.insert(\"int_tensor\".to_string(), TensorProto::INT64);\n\n        let mut traced_shapes = HashMap::new();\n        traced_shapes.insert(\"float_tensor\".to_string(), vec![1, 3, 8, 8]);\n        traced_shapes.insert(\"bool_tensor\".to_string(), vec![1, 3, 8, 8]);\n        traced_shapes.insert(\"string_tensor\".to_string(), vec![1, 3, 8, 8]);\n        traced_shapes.insert(\"int_tensor\".to_string(), vec![1, 3, 8, 8]);\n\n        let (model, metadata) = make_test_model(node_output_types, traced_shapes.clone());\n\n        let traced_types = metadata.traced_types.as_ref();\n        let combined =\n            materialize_combined_model(&model, &metadata, &traced_shapes, traced_types).unwrap();\n\n        let graph = combined.graph.as_ref().unwrap();\n\n        let float_vi = graph.output.iter().find(|o| o.name == \"float_tensor\");\n        assert!(float_vi.is_some());\n\n        let bool_vi = graph.output.iter().find(|o| o.name == \"bool_tensor\");\n        assert!(bool_vi.is_some());\n\n        let string_vi = graph.output.iter().find(|o| o.name == \"string_tensor\");\n        assert!(\n            string_vi.is_none(),\n            \"string tensors should be excluded from combined outputs\"\n        );\n\n        let int_vi = graph.output.iter().find(|o| o.name == \"int_tensor\");\n        assert!(int_vi.is_some());\n    }\n\n    #[test]\n    fn combined_model_has_intermediate_outputs() {\n        let mut 
traced_shapes = HashMap::new();\n        traced_shapes.insert(\"float_tensor\".to_string(), vec![1, 3, 8, 8]);\n        traced_shapes.insert(\"bool_tensor\".to_string(), vec![1]);\n        traced_shapes.insert(\"string_tensor\".to_string(), vec![1]);\n        traced_shapes.insert(\"int_tensor\".to_string(), vec![2, 4]);\n\n        let mut types = HashMap::new();\n        types.insert(\"float_tensor\".to_string(), TensorProto::FLOAT);\n        types.insert(\"bool_tensor\".to_string(), TensorProto::BOOL);\n        types.insert(\"int_tensor\".to_string(), TensorProto::INT64);\n\n        let (model, metadata) = make_test_model(types, traced_shapes.clone());\n        let traced_types = metadata.traced_types.as_ref();\n        let combined =\n            materialize_combined_model(&model, &metadata, &traced_shapes, traced_types).unwrap();\n\n        let graph = combined.graph.as_ref().unwrap();\n        assert!(\n            graph.output.len() > 1,\n            \"combined model should have intermediate outputs\"\n        );\n    }\n\n    #[test]\n    fn combined_model_to_disk_roundtrip() {\n        let dir = tempfile::tempdir().unwrap();\n        let slices_dir = dir.path();\n\n        let mut traced_shapes = HashMap::new();\n        traced_shapes.insert(\"float_tensor\".to_string(), vec![1, 3, 8, 8]);\n        traced_shapes.insert(\"bool_tensor\".to_string(), vec![1]);\n        traced_shapes.insert(\"string_tensor\".to_string(), vec![1]);\n        traced_shapes.insert(\"int_tensor\".to_string(), vec![2, 4]);\n\n        let mut types = HashMap::new();\n        types.insert(\"float_tensor\".to_string(), TensorProto::FLOAT);\n        types.insert(\"bool_tensor\".to_string(), TensorProto::BOOL);\n        types.insert(\"int_tensor\".to_string(), TensorProto::INT64);\n\n        let (model, mut metadata) = make_test_model(types, traced_shapes);\n        metadata.original_model_path = Some(\"model.onnx\".to_string());\n\n        let model_path = 
slices_dir.join(\"model.onnx\");\n        onnx_proto::save_model(&model, &model_path).unwrap();\n        let meta_path = slices_dir.join(\"metadata.msgpack\");\n        metadata.save(&meta_path).unwrap();\n\n        let dest = materialize_combined_to_disk(slices_dir, &metadata).unwrap();\n        assert!(dest.exists());\n\n        let loaded = onnx_proto::load_model(&dest).unwrap();\n        let graph = loaded.graph.as_ref().unwrap();\n        assert!(\n            graph.output.len() > 1,\n            \"reloaded combined model should have intermediate outputs\"\n        );\n    }\n\n    #[test]\n    fn ensure_combined_is_idempotent() {\n        let dir = tempfile::tempdir().unwrap();\n        let slices_dir = dir.path();\n\n        let mut traced_shapes = HashMap::new();\n        traced_shapes.insert(\"float_tensor\".to_string(), vec![1, 3, 8, 8]);\n        traced_shapes.insert(\"bool_tensor\".to_string(), vec![1]);\n        traced_shapes.insert(\"string_tensor\".to_string(), vec![1]);\n        traced_shapes.insert(\"int_tensor\".to_string(), vec![2, 4]);\n\n        let mut types = HashMap::new();\n        types.insert(\"float_tensor\".to_string(), TensorProto::FLOAT);\n        types.insert(\"bool_tensor\".to_string(), TensorProto::BOOL);\n        types.insert(\"int_tensor\".to_string(), TensorProto::INT64);\n\n        let (model, mut metadata) = make_test_model(types, traced_shapes);\n        metadata.original_model_path = Some(\"model.onnx\".to_string());\n\n        let model_path = slices_dir.join(\"model.onnx\");\n        onnx_proto::save_model(&model, &model_path).unwrap();\n        let meta_path = slices_dir.join(\"metadata.msgpack\");\n        metadata.save(&meta_path).unwrap();\n\n        let dest1 = materialize_combined_to_disk(slices_dir, &metadata).unwrap();\n        let dest2 = materialize_combined_to_disk(slices_dir, &metadata).unwrap();\n        assert_eq!(dest1, dest2);\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/layernorm_fuse.rs",
    "content": "use std::collections::{HashMap, HashSet};\n\nuse super::onnx_proto::{\n    AttributeProto, ModelProto, NodeProto, TensorProto, tensor_to_f32, tensor_to_i64,\n};\n\npub fn fuse_inline_layernorms(\n    model: &mut ModelProto,\n    traced_shapes: &mut HashMap<String, Vec<i64>>,\n) -> usize {\n    let graph = match model.graph.as_mut() {\n        Some(g) => g,\n        None => return 0,\n    };\n\n    let initializers: HashMap<String, TensorProto> = graph\n        .initializer\n        .iter()\n        .map(|t| (t.name.clone(), t.clone()))\n        .collect();\n\n    let producers: HashMap<String, usize> = graph\n        .node\n        .iter()\n        .enumerate()\n        .flat_map(|(i, n)| {\n            n.output\n                .iter()\n                .filter(|o| !o.is_empty())\n                .map(move |o| (o.clone(), i))\n        })\n        .collect();\n\n    let mut consumers: HashMap<String, Vec<usize>> = HashMap::new();\n    for (i, n) in graph.node.iter().enumerate() {\n        for inp in &n.input {\n            if !inp.is_empty() {\n                consumers.entry(inp.clone()).or_default().push(i);\n            }\n        }\n    }\n\n    let mut drop: HashSet<usize> = HashSet::new();\n    let mut insertions: Vec<(usize, Vec<NodeProto>, Vec<TensorProto>)> = Vec::new();\n    let mut fused_id = 0usize;\n\n    for (mean_idx, mean_node) in graph.node.iter().enumerate() {\n        if drop.contains(&mean_idx) || mean_node.op_type != \"ReduceMean\" {\n            continue;\n        }\n        let Some(m) = try_match_layernorm(\n            mean_idx,\n            mean_node,\n            &graph.node,\n            &producers,\n            &consumers,\n            &initializers,\n            traced_shapes,\n            &drop,\n        ) else {\n            continue;\n        };\n        let (nodes, inits, shapes) = emit_replacement(&m, fused_id, &initializers);\n        for (name, shape) in shapes {\n            traced_shapes.insert(name, shape);\n   
     }\n        fused_id += 1;\n        drop.extend(m.nodes_to_drop.iter().copied());\n        insertions.push((mean_idx, nodes, inits));\n    }\n\n    let fused = insertions.len();\n    if fused == 0 {\n        return 0;\n    }\n\n    for (_, _, inits) in &insertions {\n        for t in inits {\n            graph.initializer.push(t.clone());\n        }\n    }\n\n    let insertion_map: HashMap<usize, Vec<NodeProto>> = insertions\n        .into_iter()\n        .map(|(idx, nodes, _)| (idx, nodes))\n        .collect();\n\n    let mut new_nodes: Vec<NodeProto> = Vec::with_capacity(graph.node.len());\n    for (i, n) in graph.node.drain(..).enumerate() {\n        if let Some(inserts) = insertion_map.get(&i) {\n            new_nodes.extend(inserts.iter().cloned());\n            continue;\n        }\n        if drop.contains(&i) {\n            continue;\n        }\n        new_nodes.push(n);\n    }\n    graph.node = new_nodes;\n    fused\n}\n\nstruct MatchedPattern {\n    x_name: String,\n    axes: Vec<usize>,\n    rank: usize,\n    x_shape: Vec<i64>,\n    eps: f32,\n    scale_init: Option<String>,\n    bias_init: Option<String>,\n    output_name: String,\n    nodes_to_drop: Vec<usize>,\n}\n\n#[allow(clippy::too_many_arguments, clippy::too_many_lines)]\nfn try_match_layernorm(\n    mean_idx: usize,\n    mean_node: &NodeProto,\n    nodes: &[NodeProto],\n    producers: &HashMap<String, usize>,\n    consumers: &HashMap<String, Vec<usize>>,\n    initializers: &HashMap<String, TensorProto>,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n    drop: &HashSet<usize>,\n) -> Option<MatchedPattern> {\n    let raw_axes = reduce_axes(mean_node, initializers)?;\n    if get_keepdims(mean_node).unwrap_or(1) != 1 {\n        return None;\n    }\n    let x_name = mean_node.input.first()?.clone();\n    let mean_out = mean_node.output.first()?.clone();\n\n    let sub_idx = find_unique_consumer(consumers, &mean_out, \"Sub\", nodes, drop)?;\n    let sub_node = &nodes[sub_idx];\n    if 
sub_node.input.len() < 2\n        || sub_node.input.first()? != &x_name\n        || sub_node.input.get(1)? != &mean_out\n    {\n        return None;\n    }\n    let centered = sub_node.output.first()?.clone();\n\n    let sq_idx = find_square_consumer(consumers, &centered, nodes, initializers, drop)?;\n    let sq_node = &nodes[sq_idx];\n    let sq_out = sq_node.output.first()?.clone();\n\n    let mean2_idx = find_unique_consumer(consumers, &sq_out, \"ReduceMean\", nodes, drop)?;\n    let mean2_node = &nodes[mean2_idx];\n    let raw_axes2 = reduce_axes(mean2_node, initializers)?;\n    if raw_axes2 != raw_axes {\n        return None;\n    }\n    if get_keepdims(mean2_node).unwrap_or(1) != 1 {\n        return None;\n    }\n    let var_out = mean2_node.output.first()?.clone();\n\n    let add_idx = find_unique_consumer(consumers, &var_out, \"Add\", nodes, drop)?;\n    let add_node = &nodes[add_idx];\n    let eps = extract_binary_const_scalar(add_node, &var_out, initializers)?;\n    let var_eps = add_node.output.first()?.clone();\n\n    let sqrt_idx = find_unique_consumer(consumers, &var_eps, \"Sqrt\", nodes, drop)?;\n    let sqrt_node = &nodes[sqrt_idx];\n    let std_out = sqrt_node.output.first()?.clone();\n\n    let div_idx = find_unique_consumer(consumers, &std_out, \"Div\", nodes, drop)?;\n    let div_node = &nodes[div_idx];\n    if div_node.input.len() < 2\n        || div_node.input.first()? != &centered\n        || div_node.input.get(1)? 
!= &std_out\n    {\n        return None;\n    }\n    let norm_out = div_node.output.first()?.clone();\n\n    let mut nodes_to_drop = vec![\n        mean_idx, sub_idx, sq_idx, mean2_idx, add_idx, sqrt_idx, div_idx,\n    ];\n    let mut output_name = norm_out.clone();\n    let mut scale_init: Option<String> = None;\n    let mut bias_init: Option<String> = None;\n\n    if let Some(mul_idx) = find_unique_consumer(consumers, &norm_out, \"Mul\", nodes, drop) {\n        let mul_node = &nodes[mul_idx];\n        if let Some(scale) = other_input_if_init(mul_node, &norm_out, initializers) {\n            scale_init = Some(scale);\n            output_name = mul_node.output.first()?.clone();\n            nodes_to_drop.push(mul_idx);\n\n            if let Some(add2_idx) =\n                find_unique_consumer(consumers, &output_name, \"Add\", nodes, drop)\n            {\n                let add2_node = &nodes[add2_idx];\n                if let Some(bias) = other_input_if_init(add2_node, &output_name, initializers) {\n                    bias_init = Some(bias);\n                    output_name = add2_node.output.first()?.clone();\n                    nodes_to_drop.push(add2_idx);\n                }\n            }\n        }\n    }\n\n    // Soundness check: every intermediate tensor we are about to drop\n    // (mean_out, centered, sq_out, var_out, var_eps, std_out, plus the\n    // pre-affine norm_out when scale/bias are present) must have all\n    // its live consumers inside nodes_to_drop.  
Otherwise some\n    // downstream node still reads the intermediate and fusing would\n    // disconnect it.\n    let drop_set: HashSet<usize> = nodes_to_drop.iter().copied().collect();\n    let mut intermediates: Vec<&str> = vec![\n        mean_out.as_str(),\n        centered.as_str(),\n        sq_out.as_str(),\n        var_out.as_str(),\n        var_eps.as_str(),\n        std_out.as_str(),\n    ];\n    if scale_init.is_some() {\n        intermediates.push(norm_out.as_str());\n    }\n    for tname in intermediates {\n        if let Some(list) = consumers.get(tname) {\n            for &idx in list {\n                if drop.contains(&idx) || drop_set.contains(&idx) {\n                    continue;\n                }\n                return None;\n            }\n        }\n    }\n\n    let x_shape = resolve_shape(&x_name, traced_shapes, initializers, nodes, producers)?;\n    let rank = x_shape.len();\n    if rank == 0 {\n        return None;\n    }\n    let axes: Vec<usize> = raw_axes.iter().map(|&a| normalize_axis(a, rank)).collect();\n    for &a in &axes {\n        if a >= rank {\n            return None;\n        }\n        // Reject dynamic / unresolved dims along the reduction axes: the\n        // fused LayerNormalization circuit needs a concrete lane_size\n        // and consumers of m.x_shape[a] later cast the dim to usize,\n        // which silently wraps negative sentinels into huge values.\n        if x_shape[a] <= 0 {\n            return None;\n        }\n    }\n\n    Some(MatchedPattern {\n        x_name,\n        axes,\n        rank,\n        x_shape,\n        eps,\n        scale_init,\n        bias_init,\n        output_name,\n        nodes_to_drop,\n    })\n}\n\nfn resolve_shape(\n    name: &str,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n    initializers: &HashMap<String, TensorProto>,\n    _nodes: &[NodeProto],\n    _producers: &HashMap<String, usize>,\n) -> Option<Vec<i64>> {\n    if let Some(s) = traced_shapes.get(name)\n        && 
!s.is_empty()\n    {\n        return Some(s.clone());\n    }\n    if let Some(t) = initializers.get(name) {\n        return Some(t.dims.clone());\n    }\n    None\n}\n\nfn reduce_axes(node: &NodeProto, initializers: &HashMap<String, TensorProto>) -> Option<Vec<i64>> {\n    if let Some(attr) = node.attribute.iter().find(|a| a.name == \"axes\")\n        && !attr.ints.is_empty()\n    {\n        return Some(attr.ints.clone());\n    }\n    if let Some(name) = node.input.get(1)\n        && let Some(t) = initializers.get(name)\n    {\n        let v = tensor_to_i64(t);\n        if !v.is_empty() {\n            return Some(v);\n        }\n    }\n    None\n}\n\nfn get_keepdims(node: &NodeProto) -> Option<i64> {\n    node.attribute\n        .iter()\n        .find(|a| a.name == \"keepdims\")\n        .map(|a| a.i)\n}\n\nfn find_unique_consumer(\n    consumers: &HashMap<String, Vec<usize>>,\n    tensor: &str,\n    op_type: &str,\n    nodes: &[NodeProto],\n    drop: &HashSet<usize>,\n) -> Option<usize> {\n    let list = consumers.get(tensor)?;\n    let live: Vec<usize> = list.iter().copied().filter(|i| !drop.contains(i)).collect();\n    if live.len() != 1 {\n        return None;\n    }\n    let idx = live[0];\n    (nodes[idx].op_type == op_type).then_some(idx)\n}\n\nfn find_square_consumer(\n    consumers: &HashMap<String, Vec<usize>>,\n    tensor: &str,\n    nodes: &[NodeProto],\n    initializers: &HashMap<String, TensorProto>,\n    drop: &HashSet<usize>,\n) -> Option<usize> {\n    // The centered tensor in the inline-LN pattern has TWO legitimate\n    // consumers: Pow / Mul (for the variance branch) AND Div (for the\n    // normalization branch).  Both belong to the fusion -- don't reject\n    // them as orphan consumers.  
Final orphan-leak check happens after\n    // the whole pattern matches in try_match_layernorm.\n    let list = consumers.get(tensor)?;\n    for &idx in list.iter().filter(|i| !drop.contains(i)) {\n        let n = &nodes[idx];\n        match n.op_type.as_str() {\n            \"Pow\" => {\n                if n.input.len() >= 2\n                    && n.input.first().map(String::as_str) == Some(tensor)\n                    && pow_exponent_is_two(n.input.get(1)?, initializers)\n                {\n                    return Some(idx);\n                }\n            }\n            \"Mul\" => {\n                if n.input.len() == 2 && n.input.iter().all(|i| i == tensor) {\n                    return Some(idx);\n                }\n            }\n            _ => {}\n        }\n    }\n    None\n}\n\nfn pow_exponent_is_two(name: &str, initializers: &HashMap<String, TensorProto>) -> bool {\n    let Some(t) = initializers.get(name) else {\n        return false;\n    };\n    let f = tensor_to_f32(t);\n    if let Some(&v) = f.first()\n        && (v - 2.0).abs() < f32::EPSILON\n    {\n        return true;\n    }\n    let i = tensor_to_i64(t);\n    matches!(i.first(), Some(&2))\n}\n\nfn extract_binary_const_scalar(\n    node: &NodeProto,\n    non_const_input: &str,\n    initializers: &HashMap<String, TensorProto>,\n) -> Option<f32> {\n    if node.input.len() != 2 {\n        return None;\n    }\n    let (a, b) = (node.input.first()?, node.input.get(1)?);\n    let other_name = if a.as_str() == non_const_input {\n        b\n    } else if b.as_str() == non_const_input {\n        a\n    } else {\n        return None;\n    };\n    let t = initializers.get(other_name)?;\n    tensor_to_f32(t).first().copied()\n}\n\nfn other_input_if_init(\n    node: &NodeProto,\n    non_const_input: &str,\n    initializers: &HashMap<String, TensorProto>,\n) -> Option<String> {\n    if node.input.len() != 2 {\n        return None;\n    }\n    let a = node.input.first()?.clone();\n    let b = 
node.input.get(1)?.clone();\n    let other = if a == non_const_input {\n        b\n    } else if b == non_const_input {\n        a\n    } else {\n        return None;\n    };\n    initializers.get(&other).map(|_| other)\n}\n\ntype ReplacementShapes = Vec<(String, Vec<i64>)>;\ntype Replacement = (Vec<NodeProto>, Vec<TensorProto>, ReplacementShapes);\n\nfn emit_replacement(\n    m: &MatchedPattern,\n    fused_id: usize,\n    initializers: &HashMap<String, TensorProto>,\n) -> Replacement {\n    let rank = m.rank;\n    let axes_set: HashSet<usize> = m.axes.iter().copied().collect();\n    let mut forward_perm: Vec<i64> = (0..rank)\n        .filter(|d| !axes_set.contains(d))\n        .map(|d| d as i64)\n        .collect();\n    for &a in &m.axes {\n        forward_perm.push(a as i64);\n    }\n    let mut inverse_perm: Vec<i64> = vec![0; rank];\n    for (new_pos, &old_pos) in forward_perm.iter().enumerate() {\n        inverse_perm[old_pos as usize] = new_pos as i64;\n    }\n\n    let lane_size: usize = m.axes.iter().map(|&a| m.x_shape[a] as usize).product();\n\n    let prefix = format!(\"/__dsperse/fused_ln_{fused_id}\");\n    let xt_name = format!(\"{prefix}/xt\");\n    let yt_name = format!(\"{prefix}/yt\");\n\n    let (scale_name, scale_init_opt) = materialize_1d_initializer(\n        &format!(\"{prefix}/scale\"),\n        m.scale_init.as_deref(),\n        initializers,\n        lane_size,\n        1.0,\n    );\n    let (bias_name, bias_init_opt) = materialize_1d_initializer(\n        &format!(\"{prefix}/bias\"),\n        m.bias_init.as_deref(),\n        initializers,\n        lane_size,\n        0.0,\n    );\n\n    let mut nodes = Vec::new();\n\n    nodes.push(NodeProto {\n        name: format!(\"{prefix}/Transpose_in\"),\n        op_type: \"Transpose\".to_string(),\n        input: vec![m.x_name.clone()],\n        output: vec![xt_name.clone()],\n        attribute: vec![int_list_attr(\"perm\", &forward_perm)],\n        ..Default::default()\n    });\n\n    let ln_axis = 
(rank - m.axes.len()) as i64;\n    nodes.push(NodeProto {\n        name: format!(\"{prefix}/LayerNormalization\"),\n        op_type: \"LayerNormalization\".to_string(),\n        input: vec![xt_name, scale_name, bias_name],\n        output: vec![yt_name.clone()],\n        attribute: vec![int_attr(\"axis\", ln_axis), float_attr(\"epsilon\", m.eps)],\n        ..Default::default()\n    });\n\n    nodes.push(NodeProto {\n        name: format!(\"{prefix}/Transpose_out\"),\n        op_type: \"Transpose\".to_string(),\n        input: vec![yt_name],\n        output: vec![m.output_name.clone()],\n        attribute: vec![int_list_attr(\"perm\", &inverse_perm)],\n        ..Default::default()\n    });\n\n    let mut inits = Vec::new();\n    inits.extend(scale_init_opt);\n    inits.extend(bias_init_opt);\n\n    let xt_shape: Vec<i64> = forward_perm\n        .iter()\n        .map(|&p| m.x_shape[p as usize])\n        .collect();\n    let yt_shape = xt_shape.clone();\n    let shapes = vec![\n        (format!(\"{prefix}/xt\"), xt_shape),\n        (format!(\"{prefix}/yt\"), yt_shape),\n    ];\n\n    (nodes, inits, shapes)\n}\n\nfn materialize_1d_initializer(\n    new_name: &str,\n    source: Option<&str>,\n    initializers: &HashMap<String, TensorProto>,\n    lane_size: usize,\n    default_fill: f32,\n) -> (String, Option<TensorProto>) {\n    let Some(src) = source else {\n        return (\n            new_name.to_string(),\n            Some(const_vector(new_name, lane_size, default_fill)),\n        );\n    };\n    let Some(t) = initializers.get(src) else {\n        return (\n            new_name.to_string(),\n            Some(const_vector(new_name, lane_size, default_fill)),\n        );\n    };\n    let elems = tensor_to_f32(t);\n    if elems.len() == lane_size && t.dims.len() == 1 {\n        return (src.to_string(), None);\n    }\n    let vals: Vec<f32> = if elems.len() == lane_size {\n        elems\n    } else if elems.len() == 1 {\n        vec![elems[0]; lane_size]\n    } else 
{\n        return (\n            new_name.to_string(),\n            Some(const_vector(new_name, lane_size, default_fill)),\n        );\n    };\n    (new_name.to_string(), Some(make_f32_vector(new_name, &vals)))\n}\n\nfn const_vector(name: &str, len: usize, fill: f32) -> TensorProto {\n    make_f32_vector(name, &vec![fill; len])\n}\n\nfn make_f32_vector(name: &str, vals: &[f32]) -> TensorProto {\n    TensorProto {\n        name: name.to_string(),\n        data_type: TensorProto::FLOAT,\n        dims: vec![vals.len() as i64],\n        float_data: vals.to_vec(),\n        ..Default::default()\n    }\n}\n\nfn normalize_axis(axis: i64, rank: usize) -> usize {\n    if axis < 0 {\n        (rank as i64 + axis) as usize\n    } else {\n        axis as usize\n    }\n}\n\nfn int_attr(name: &str, v: i64) -> AttributeProto {\n    AttributeProto {\n        name: name.to_string(),\n        r#type: 2,\n        i: v,\n        ..Default::default()\n    }\n}\n\nfn float_attr(name: &str, v: f32) -> AttributeProto {\n    AttributeProto {\n        name: name.to_string(),\n        r#type: 1,\n        f: v,\n        ..Default::default()\n    }\n}\n\nfn int_list_attr(name: &str, vals: &[i64]) -> AttributeProto {\n    AttributeProto {\n        name: name.to_string(),\n        r#type: 7,\n        ints: vals.to_vec(),\n        ..Default::default()\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/materializer.rs",
    "content": "use std::collections::{HashMap, HashSet};\nuse std::path::{Path, PathBuf};\n\nuse super::autotiler::{self, ChannelSplitParams};\nuse super::onnx_proto::{self, GraphProto, ModelProto, NodeProto, TensorProto, ValueInfoProto};\nuse super::onnx_slicer::broadcast_shapes;\nuse crate::error::{DsperseError, Result};\nuse crate::schema::metadata::ModelMetadata;\n\nconst MAX_BACKWARD_DEPTH: usize = 64;\n\nfn resolve_shape_backward(\n    tensor_name: &str,\n    graph: &GraphProto,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n) -> Option<Vec<i64>> {\n    resolve_shape_backward_inner(tensor_name, graph, traced_shapes, 0)\n}\n\nfn resolve_shape_backward_inner(\n    tensor_name: &str,\n    graph: &GraphProto,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n    depth: usize,\n) -> Option<Vec<i64>> {\n    if depth > MAX_BACKWARD_DEPTH {\n        return None;\n    }\n\n    if let Some(s) = traced_shapes.get(tensor_name) {\n        return Some(s.clone());\n    }\n\n    if let Some(vi) = graph.value_info.iter().find(|v| v.name == tensor_name)\n        && let Some(shape) = onnx_proto::shape_from_value_info(vi)\n    {\n        return Some(shape);\n    }\n\n    for init in &graph.initializer {\n        if init.name == tensor_name {\n            return Some(init.dims.to_vec());\n        }\n    }\n\n    let producer = graph\n        .node\n        .iter()\n        .find(|n| n.output.contains(&tensor_name.to_string()))?;\n    let op = producer.op_type.as_str();\n\n    if super::is_shape_preserving(op) {\n        let inp = producer.input.first()?;\n        return resolve_shape_backward_inner(inp, graph, traced_shapes, depth + 1);\n    }\n\n    if op == \"Shape\" {\n        let inp = producer.input.first()?;\n        let in_shape = resolve_shape_backward_inner(inp, graph, traced_shapes, depth + 1)?;\n        return Some(vec![in_shape.len() as i64]);\n    }\n\n    if super::is_binary_arithmetic(op) {\n        let resolved: Vec<Vec<i64>> = producer\n            .input\n     
       .iter()\n            .filter_map(|inp| resolve_shape_backward_inner(inp, graph, traced_shapes, depth + 1))\n            .collect();\n        let refs: Vec<&Vec<i64>> = resolved.iter().collect();\n        if let Some(broadcasted) = broadcast_shapes(&refs) {\n            return Some(broadcasted);\n        }\n    }\n\n    None\n}\n\npub fn materialize_slice_model(\n    model: &ModelProto,\n    slice_points: &[usize],\n    traced_shapes: &HashMap<String, Vec<i64>>,\n    traced_types: &HashMap<String, i32>,\n    slice_idx: usize,\n) -> Result<ModelProto> {\n    let graph = model\n        .graph\n        .as_ref()\n        .ok_or_else(|| DsperseError::Slicer(\"model.graph is None\".into()))?;\n\n    let total_nodes = graph.node.len();\n    let segment_ranges = super::build_segment_ranges(slice_points, Some(total_nodes));\n    let &(start, end) = segment_ranges.get(slice_idx).ok_or_else(|| {\n        DsperseError::Slicer(format!(\n            \"slice index {slice_idx} out of range (have {} segments)\",\n            segment_ranges.len()\n        ))\n    })?;\n\n    let init_map: HashMap<&str, &TensorProto> = graph\n        .initializer\n        .iter()\n        .map(|i| (i.name.as_str(), i))\n        .collect();\n\n    let vi_map = onnx_proto::build_value_info_map(graph);\n\n    let init_types: HashMap<&str, i32> = graph\n        .initializer\n        .iter()\n        .map(|i| (i.name.as_str(), i.data_type))\n        .collect();\n\n    let node_output_types = build_node_output_types(graph);\n    let future_deps = compute_future_dependencies(graph, &segment_ranges, &init_map);\n\n    let constant_producers: HashMap<String, &TensorProto> = graph\n        .node\n        .iter()\n        .filter(|n| n.op_type == \"Constant\")\n        .flat_map(|n| {\n            n.output.iter().filter_map(|out| {\n                n.attribute\n                    .iter()\n                    .find(|a| a.name == \"value\")\n                    .and_then(|a| a.t.as_ref())\n                
    .map(|t| (out.clone(), t))\n            })\n        })\n        .collect();\n\n    let nodes: Vec<NodeProto> = graph.node[start..end].to_vec();\n\n    let seg_outputs: HashSet<String> = nodes\n        .iter()\n        .flat_map(|n| n.output.iter().cloned())\n        .collect();\n\n    let seg_inputs_set: HashSet<String> = nodes\n        .iter()\n        .flat_map(|n| {\n            let mut inputs: Vec<String> =\n                n.input.iter().filter(|s| !s.is_empty()).cloned().collect();\n            if super::is_control_flow(&n.op_type) {\n                let outer_refs = super::collect_subgraph_outer_refs(n, graph);\n                inputs.extend(outer_refs);\n            }\n            inputs\n        })\n        .collect();\n\n    let future = future_deps.get(&slice_idx).cloned().unwrap_or_default();\n\n    let query = SegmentQuery {\n        nodes: &nodes,\n        seg_outputs: &seg_outputs,\n        seg_inputs_set: &seg_inputs_set,\n        future_inputs: &future,\n    };\n    let ctx = ShapeContext {\n        graph,\n        init_map: &init_map,\n        vi_map: &vi_map,\n        traced_shapes,\n        traced_types,\n        init_types: &init_types,\n        node_output_types: &node_output_types,\n        constant_producers: &constant_producers,\n    };\n    let (inputs, outputs, initializers) = get_segment_details(&query, &ctx)?;\n\n    let opset_version = model\n        .opset_import\n        .iter()\n        .find(|o| o.domain.is_empty() || o.domain == \"ai.onnx\")\n        .map(|o| o.version)\n        .unwrap_or(13);\n\n    let seg_graph = onnx_proto::make_graph(\n        &format!(\"segment_{slice_idx}_graph\"),\n        nodes,\n        inputs,\n        outputs,\n        initializers,\n    );\n    Ok(onnx_proto::make_model(seg_graph, opset_version))\n}\n\npub fn materialize_slice_to_disk(\n    model: &ModelProto,\n    slice_points: &[usize],\n    traced_shapes: &HashMap<String, Vec<i64>>,\n    traced_types: &HashMap<String, i32>,\n    slice_idx: 
usize,\n    output_path: &Path,\n) -> Result<PathBuf> {\n    let slice_model =\n        materialize_slice_model(model, slice_points, traced_shapes, traced_types, slice_idx)?;\n    if let Some(parent) = output_path.parent() {\n        std::fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?;\n    }\n    onnx_proto::save_model(&slice_model, output_path)?;\n    Ok(output_path.to_path_buf())\n}\n\npub fn ensure_slice_materialized(\n    slices_dir: &Path,\n    metadata: &ModelMetadata,\n    slice_idx: usize,\n) -> Result<PathBuf> {\n    let slice_meta = metadata\n        .slices\n        .iter()\n        .find(|s| s.index == slice_idx)\n        .ok_or_else(|| DsperseError::Slicer(format!(\"no slice metadata for index {slice_idx}\")))?;\n\n    let slice_dir = slices_dir.join(format!(\"slice_{slice_idx}\"));\n    let payload_dir = slice_dir.join(\"payload\");\n    let onnx_path = payload_dir.join(format!(\"slice_{slice_idx}.onnx\"));\n\n    if onnx_path.exists() {\n        materialize_tiling_artifacts(slices_dir, metadata, slice_meta, slice_idx)?;\n        return Ok(onnx_path);\n    }\n\n    if !slice_dir.exists() {\n        let archive = slices_dir.join(format!(\"slice_{slice_idx}.dslice\"));\n        if archive.exists() {\n            extract_dslice_archive(&archive, &slice_dir)?;\n            if onnx_path.exists() {\n                materialize_tiling_artifacts(slices_dir, metadata, slice_meta, slice_idx)?;\n                return Ok(onnx_path);\n            }\n        }\n    }\n\n    let traced_shapes = metadata.traced_shapes.as_ref().ok_or_else(|| {\n        DsperseError::Slicer(\"metadata missing traced_shapes for materialization\".into())\n    })?;\n    let empty_types: HashMap<String, i32> = HashMap::new();\n    let traced_types = metadata.traced_types.as_ref().unwrap_or(&empty_types);\n    let original_path = metadata.original_model_path.as_ref().ok_or_else(|| {\n        DsperseError::Slicer(\"metadata missing original_model_path for 
materialization\".into())\n    })?;\n\n    let model_path = if Path::new(original_path).is_absolute() {\n        PathBuf::from(original_path)\n    } else {\n        slices_dir.join(original_path)\n    };\n\n    let mut model = onnx_proto::load_model(&model_path)?;\n    onnx_proto::normalize_opset(&mut model);\n    let model_with_shapes = apply_traced_shapes(model, traced_shapes);\n\n    std::fs::create_dir_all(&payload_dir).map_err(|e| DsperseError::io(e, &payload_dir))?;\n    materialize_slice_to_disk(\n        &model_with_shapes,\n        &metadata.slice_points,\n        traced_shapes,\n        traced_types,\n        slice_idx,\n        &onnx_path,\n    )?;\n\n    tracing::info!(slice = slice_idx, path = %onnx_path.display(), \"materialized slice\");\n\n    materialize_tiling_artifacts(slices_dir, metadata, slice_meta, slice_idx)?;\n\n    Ok(onnx_path)\n}\n\nfn materialize_tiling_artifacts(\n    slices_dir: &Path,\n    metadata: &ModelMetadata,\n    slice_meta: &crate::schema::metadata::SliceMetadata,\n    slice_idx: usize,\n) -> Result<()> {\n    let payload_dir = slices_dir\n        .join(format!(\"slice_{slice_idx}\"))\n        .join(\"payload\");\n\n    if let Some(ref tiling) = slice_meta.tiling\n        && let Some(ref tile) = tiling.tile\n    {\n        let tile_path = slices_dir.join(&tile.path);\n        if !tile_path.exists() {\n            let onnx_path = payload_dir.join(format!(\"slice_{slice_idx}.onnx\"));\n            let slice_model = onnx_proto::load_model(&onnx_path)?;\n            let is_ew = slice_model.graph.as_ref().is_some_and(|g| {\n                !g.node.is_empty() && g.node.iter().all(|n| super::is_elementwise(&n.op_type))\n            });\n            let is_pool = slice_model\n                .graph\n                .as_ref()\n                .is_some_and(|g| g.node.iter().any(|n| n.op_type == \"MaxPool\"));\n            if is_ew {\n                let seg_size = tiling.segment_size.ok_or_else(|| {\n                    
crate::error::DsperseError::Slicer(format!(\n                        \"slice {slice_idx}: elementwise tiling metadata missing segment_size; re-slice the model\"\n                    ))\n                })? as i64;\n                autotiler::create_elementwise_tile_slice(\n                    &slice_model,\n                    seg_size,\n                    slice_idx,\n                    &payload_dir,\n                )?;\n            } else if is_pool {\n                autotiler::create_pool_tile_slice(\n                    &slice_model,\n                    tiling.tile_size as i64,\n                    slice_idx,\n                    &payload_dir,\n                )?;\n            } else {\n                autotiler::create_tile_slice(\n                    &slice_model,\n                    tiling.tile_size as i64,\n                    slice_idx,\n                    &payload_dir,\n                )?;\n            }\n            tracing::info!(slice = slice_idx, \"materialized tile ONNX\");\n        }\n    }\n\n    if let Some(ref cs) = slice_meta.channel_split {\n        let needs_materialization = cs.groups.is_empty()\n            || cs.groups.iter().any(|g| {\n                let group_path = slices_dir.join(&g.path);\n                !group_path.exists()\n            });\n\n        if needs_materialization && cs.num_groups > 0 {\n            let onnx_path = payload_dir.join(format!(\"slice_{slice_idx}.onnx\"));\n            let slice_model = onnx_proto::load_model(&onnx_path)?;\n\n            let params = ChannelSplitParams {\n                c_in: cs.c_in as i64,\n                c_out: cs.c_out as i64,\n                num_groups: cs.num_groups as i64,\n                channels_per_group: cs.channels_per_group as i64,\n                h: cs.h as i64,\n                w: cs.w as i64,\n                slice_idx,\n            };\n            let cs_info = autotiler::apply_channel_splitting(\n                &slice_model,\n                &params,\n           
     &cs.input_name,\n                &cs.output_name,\n                &payload_dir,\n            )?;\n            tracing::info!(\n                slice = slice_idx,\n                groups = cs_info.groups.len(),\n                \"materialized channel groups\"\n            );\n        }\n    }\n\n    if let Some(ref ds) = slice_meta.dim_split\n        && ds.num_groups > 0\n    {\n        let tmpl_path = payload_dir.join(\"dim_template.onnx\");\n        if !tmpl_path.exists() {\n            let onnx_path = payload_dir.join(format!(\"slice_{slice_idx}.onnx\"));\n            let slice_model = onnx_proto::load_model(&onnx_path)?;\n            match autotiler::create_dim_split_template(\n                &slice_model,\n                ds,\n                &payload_dir,\n                metadata.traced_shapes.as_ref(),\n            ) {\n                Ok(_) => {\n                    tracing::info!(slice = slice_idx, \"materialized dim-split template\");\n                }\n                Err(e) => {\n                    tracing::info!(\n                        slice = slice_idx,\n                        error = %e,\n                        \"dim-split template skipped, will compile as single slice\"\n                    );\n                }\n            }\n        }\n    }\n\n    Ok(())\n}\n\npub fn ensure_all_slices_materialized(slices_dir: &Path, metadata: &ModelMetadata) -> Result<()> {\n    use rayon::prelude::*;\n\n    metadata.slices.par_iter().try_for_each(|slice| {\n        ensure_slice_materialized(slices_dir, metadata, slice.index).map(|_| ())\n    })\n}\n\nfn apply_traced_shapes(mut model: ModelProto, shapes: &HashMap<String, Vec<i64>>) -> ModelProto {\n    fn set_shape(vi: &mut ValueInfoProto, shape: &[i64]) {\n        if let Some(ref mut tp) = vi.r#type\n            && let Some(onnx_proto::onnx::type_proto::Value::TensorType(ref mut tt)) = tp.value\n        {\n            tt.shape = Some(onnx_proto::onnx::TensorShapeProto {\n                dim: 
shape\n                    .iter()\n                    .map(|&d| onnx_proto::onnx::tensor_shape_proto::Dimension {\n                        denotation: String::new(),\n                        value: Some(\n                            onnx_proto::onnx::tensor_shape_proto::dimension::Value::DimValue(d),\n                        ),\n                    })\n                    .collect(),\n            });\n        }\n    }\n\n    if let Some(ref mut graph) = model.graph {\n        for inp in &mut graph.input {\n            if let Some(shape) = shapes.get(&inp.name) {\n                set_shape(inp, shape);\n            }\n        }\n        for out in &mut graph.output {\n            if let Some(shape) = shapes.get(&out.name) {\n                set_shape(out, shape);\n            }\n        }\n        for vi in &mut graph.value_info {\n            if let Some(shape) = shapes.get(&vi.name) {\n                set_shape(vi, shape);\n            }\n        }\n\n        let existing: HashSet<String> = graph\n            .input\n            .iter()\n            .chain(graph.output.iter())\n            .chain(graph.value_info.iter())\n            .map(|vi| vi.name.clone())\n            .collect();\n\n        let init_types: HashMap<&str, i32> = graph\n            .initializer\n            .iter()\n            .map(|i| (i.name.as_str(), i.data_type))\n            .collect();\n\n        let node_output_types = build_node_output_types(graph);\n\n        for (name, shape) in shapes {\n            if !existing.contains(name) {\n                let from_init = init_types.get(name.as_str()).copied();\n                let from_node = node_output_types.get(name).copied();\n                let elem_type = from_init.or(from_node).unwrap_or_else(|| {\n                    tracing::debug!(tensor = %name, \"no explicit dtype in initializers or node outputs, assuming FLOAT\");\n                    TensorProto::FLOAT\n                });\n                graph\n                    
.value_info\n                    .push(onnx_proto::make_tensor_value_info(name, elem_type, shape));\n            }\n        }\n    }\n    model\n}\n\nfn compute_future_dependencies(\n    graph: &GraphProto,\n    segment_ranges: &[(usize, usize)],\n    init_map: &HashMap<&str, &TensorProto>,\n) -> HashMap<usize, HashSet<String>> {\n    let mut seg_inputs: HashMap<usize, HashSet<String>> = HashMap::new();\n\n    for (seg_idx, &(start, end)) in segment_ranges.iter().enumerate() {\n        let seg_outputs: HashSet<String> = graph.node[start..end]\n            .iter()\n            .flat_map(|n| n.output.iter().cloned())\n            .collect();\n\n        let inputs: HashSet<String> = graph.node[start..end]\n            .iter()\n            .flat_map(|n| {\n                if super::is_control_flow(&n.op_type) {\n                    let outer_refs = super::collect_subgraph_outer_refs(n, graph);\n                    return outer_refs\n                        .into_iter()\n                        .chain(n.input.iter().cloned())\n                        .collect::<Vec<String>>();\n                }\n                n.input.to_vec()\n            })\n            .filter(|inp| {\n                !inp.is_empty()\n                    && !seg_outputs.contains(inp.as_str())\n                    && !init_map.contains_key(inp.as_str())\n            })\n            .collect();\n\n        seg_inputs.insert(seg_idx, inputs);\n    }\n\n    let mut future: HashMap<usize, HashSet<String>> = HashMap::new();\n    for seg_idx in 0..segment_ranges.len() {\n        let mut deps = HashSet::new();\n        for future_idx in (seg_idx + 1)..segment_ranges.len() {\n            if let Some(inputs) = seg_inputs.get(&future_idx) {\n                deps.extend(inputs.iter().cloned());\n            }\n        }\n        future.insert(seg_idx, deps);\n    }\n    future\n}\n\nstruct SegmentQuery<'a> {\n    nodes: &'a [NodeProto],\n    seg_outputs: &'a HashSet<String>,\n    seg_inputs_set: &'a 
HashSet<String>,\n    future_inputs: &'a HashSet<String>,\n}\n\nstruct ShapeContext<'a> {\n    graph: &'a GraphProto,\n    init_map: &'a HashMap<&'a str, &'a TensorProto>,\n    vi_map: &'a HashMap<String, &'a ValueInfoProto>,\n    traced_shapes: &'a HashMap<String, Vec<i64>>,\n    traced_types: &'a HashMap<String, i32>,\n    init_types: &'a HashMap<&'a str, i32>,\n    node_output_types: &'a HashMap<String, i32>,\n    constant_producers: &'a HashMap<String, &'a TensorProto>,\n}\n\nimpl ShapeContext<'_> {\n    fn resolve_elem_type(&self, name: &str) -> i32 {\n        // Resolution order: parent value_info dtype is implicitly used\n        // upstream via vi_map; fall back to initializer dtype, then to\n        // dtype-aware tract trace, then to the small set of ops whose\n        // output dtype is fixed by the spec, and finally to FLOAT.\n        // Skipping the traced_types lookup is the bug that turned\n        // INT64 indices (TopK, Tile-of-int, Slice-of-int) into FLOAT\n        // value_info entries; the witness path then quantised them by\n        // alpha and produced indices like 138_149_888 = 1054 * 2^17.\n        self.init_types\n            .get(name)\n            .copied()\n            .or_else(|| self.traced_types.get(name).copied())\n            .or_else(|| self.node_output_types.get(name).copied())\n            .unwrap_or(TensorProto::FLOAT)\n    }\n}\n\nfn get_segment_details(\n    query: &SegmentQuery<'_>,\n    ctx: &ShapeContext<'_>,\n) -> Result<(Vec<ValueInfoProto>, Vec<ValueInfoProto>, Vec<TensorProto>)> {\n    let mut inputs = Vec::new();\n    let mut outputs = Vec::new();\n    let mut initializers = Vec::new();\n\n    let model_output_names: HashSet<String> =\n        ctx.graph.output.iter().map(|o| o.name.clone()).collect();\n\n    let mut added_inputs: HashSet<String> = HashSet::new();\n    let mut sorted_inputs: Vec<_> = query.seg_inputs_set.iter().collect();\n    sorted_inputs.sort();\n    for inp_name in sorted_inputs {\n        if 
query.seg_outputs.contains(inp_name) {\n            continue;\n        }\n        if ctx.init_map.contains_key(inp_name.as_str()) {\n            initializers.push((*ctx.init_map[inp_name.as_str()]).clone());\n        } else if ctx.constant_producers.contains_key(inp_name) {\n            let mut tensor = ctx.constant_producers[inp_name].clone();\n            tensor.name = inp_name.clone();\n            initializers.push(tensor);\n        } else if !added_inputs.contains(inp_name) {\n            if let Some(vi) = ctx.vi_map.get(inp_name) {\n                inputs.push((*vi).clone());\n            } else {\n                let shape = ctx\n                    .traced_shapes\n                    .get(inp_name)\n                    .cloned()\n                    .or_else(|| resolve_shape_backward(inp_name, ctx.graph, ctx.traced_shapes))\n                    .ok_or_else(|| {\n                        DsperseError::Slicer(format!(\n                            \"no traced shape for segment input tensor '{inp_name}'\"\n                        ))\n                    })?;\n                inputs.push(onnx_proto::make_tensor_value_info(\n                    inp_name,\n                    ctx.resolve_elem_type(inp_name),\n                    &shape,\n                ));\n            }\n            added_inputs.insert(inp_name.clone());\n        }\n    }\n\n    let mut sorted_outputs: Vec<_> = query.seg_outputs.iter().collect();\n    sorted_outputs.sort();\n    for out_name in sorted_outputs {\n        if ctx.constant_producers.contains_key(out_name) {\n            continue;\n        }\n        let consumed_internally = query.nodes.iter().any(|n| n.input.contains(out_name));\n        let needed_externally =\n            query.future_inputs.contains(out_name) || model_output_names.contains(out_name);\n\n        if !consumed_internally || needed_externally {\n            if let Some(vi) = ctx.vi_map.get(out_name) {\n                outputs.push((*vi).clone());\n            } else 
{\n                let shape = ctx\n                    .traced_shapes\n                    .get(out_name)\n                    .cloned()\n                    .or_else(|| resolve_shape_backward(out_name, ctx.graph, ctx.traced_shapes))\n                    .ok_or_else(|| {\n                        DsperseError::Slicer(format!(\n                            \"no traced shape for segment output tensor '{out_name}'\"\n                        ))\n                    })?;\n                outputs.push(onnx_proto::make_tensor_value_info(\n                    out_name,\n                    ctx.resolve_elem_type(out_name),\n                    &shape,\n                ));\n            }\n        }\n    }\n\n    Ok((inputs, outputs, initializers))\n}\n\npub fn build_node_output_types(graph: &GraphProto) -> HashMap<String, i32> {\n    // Propagate ONNX output dtypes in topological order so that\n    // every node's output dtype is derivable from already-resolved\n    // input dtypes.  This is the source of truth for the slicer\n    // when tract's runtime trace falls back to f32 because it can't\n    // statically evaluate a node (e.g. TopK with a runtime K, or\n    // any node downstream of one that taints during tract's\n    // best-effort eval).  
Without this, INT64 indices flowing\n    // through Tile / Slice / Reshape / GatherElements get tagged\n    // as FLOAT in slice value_info and the witness path quantises\n    // them by alpha, producing nonsense indices like\n    // 1054 * 2^17 = 138_149_888.\n    let mut types: HashMap<String, i32> = HashMap::new();\n    for init in &graph.initializer {\n        if init.data_type != 0 {\n            types.insert(init.name.clone(), init.data_type);\n        }\n    }\n    for vi in graph\n        .input\n        .iter()\n        .chain(graph.value_info.iter())\n        .chain(graph.output.iter())\n    {\n        if let Some(dt) = onnx_proto::elem_type_from_value_info(vi)\n            && dt != 0\n            && !types.contains_key(&vi.name)\n        {\n            types.insert(vi.name.clone(), dt);\n        }\n    }\n    let pass_through_first_input: &[&str] = &[\n        \"Tile\",\n        \"Slice\",\n        \"Reshape\",\n        \"Transpose\",\n        \"Squeeze\",\n        \"Unsqueeze\",\n        \"Identity\",\n        \"Flatten\",\n        \"Expand\",\n        \"Concat\",\n        \"Gather\",\n        \"GatherElements\",\n        \"GatherND\",\n        \"Pad\",\n        \"Compress\",\n        \"ScatterND\",\n        \"ScatterElements\",\n        \"Scatter\",\n        \"Split\",\n        \"DepthToSpace\",\n        \"SpaceToDepth\",\n        \"ReverseSequence\",\n        \"OneHot\",\n        \"Resize\",\n        \"Upsample\",\n        \"Crop\",\n        // Unique preserves input dtype on its first output (output[0]\n        // values; output[1..3] indices/inverse/counts default to\n        // INT64 but are not standard pass-through targets and the\n        // slicer rarely sees them as graph-internal value_info\n        // entries -- defer per-output handling until we encounter a\n        // real model that needs it).\n        \"Unique\",\n    ];\n    let always_int64: &[&str] = &[\n        \"Shape\",\n        \"NonZero\",\n        \"ArgMax\",\n        
\"ArgMin\",\n        // NonMaxSuppression's selected_indices output is INT64.\n        \"NonMaxSuppression\",\n    ];\n    let always_bool: &[&str] = &[\n        \"Equal\",\n        \"Less\",\n        \"LessOrEqual\",\n        \"Greater\",\n        \"GreaterOrEqual\",\n        \"And\",\n        \"Or\",\n        \"Not\",\n        \"Xor\",\n        \"IsNaN\",\n        \"IsInf\",\n    ];\n\n    for node in &graph.node {\n        match node.op_type.as_str() {\n            \"Cast\" => {\n                if let Some(to) = onnx_proto::get_attribute_int(node, \"to\") {\n                    for out in &node.output {\n                        if !out.is_empty() {\n                            types.insert(out.clone(), to as i32);\n                        }\n                    }\n                }\n            }\n            \"Constant\" => {\n                if let Some(t) = node\n                    .attribute\n                    .iter()\n                    .find(|a| a.name == \"value\")\n                    .and_then(|a| a.t.as_ref())\n                    && let Some(out) = node.output.first()\n                    && !out.is_empty()\n                {\n                    types.insert(out.clone(), t.data_type);\n                }\n            }\n            \"ConstantOfShape\" => {\n                let dt = node\n                    .attribute\n                    .iter()\n                    .find(|a| a.name == \"value\")\n                    .and_then(|a| a.t.as_ref())\n                    .map(|t| t.data_type)\n                    .unwrap_or(TensorProto::FLOAT);\n                for out in &node.output {\n                    if !out.is_empty() {\n                        types.insert(out.clone(), dt);\n                    }\n                }\n            }\n            \"MaxPool\" => {\n                if let Some(idx_out) = node.output.get(1)\n                    && !idx_out.is_empty()\n                {\n                    types.insert(idx_out.clone(), 
TensorProto::INT64);\n                }\n                if let Some(val_out) = node.output.first()\n                    && !val_out.is_empty()\n                    && let Some(in_name) = node.input.first()\n                    && let Some(&dt) = types.get(in_name.as_str())\n                {\n                    types.insert(val_out.clone(), dt);\n                }\n            }\n            \"TopK\" => {\n                if let Some(val_out) = node.output.first()\n                    && !val_out.is_empty()\n                    && let Some(in_name) = node.input.first()\n                    && let Some(&dt) = types.get(in_name.as_str())\n                {\n                    types.insert(val_out.clone(), dt);\n                }\n                if let Some(idx_out) = node.output.get(1)\n                    && !idx_out.is_empty()\n                {\n                    types.insert(idx_out.clone(), TensorProto::INT64);\n                }\n            }\n            \"Where\" => {\n                if let Some(out) = node.output.first()\n                    && !out.is_empty()\n                {\n                    if let Some(&dt) = node.input.get(1).and_then(|n| types.get(n.as_str())) {\n                        types.insert(out.clone(), dt);\n                    } else if let Some(&dt) = node.input.get(2).and_then(|n| types.get(n.as_str()))\n                    {\n                        types.insert(out.clone(), dt);\n                    }\n                }\n            }\n            op if always_int64.contains(&op) => {\n                for out in &node.output {\n                    if !out.is_empty() {\n                        types.insert(out.clone(), TensorProto::INT64);\n                    }\n                }\n            }\n            op if always_bool.contains(&op) => {\n                for out in &node.output {\n                    if !out.is_empty() {\n                        types.insert(out.clone(), TensorProto::BOOL);\n                    }\n     
           }\n            }\n            op if pass_through_first_input.contains(&op) => {\n                if let Some(in_name) = node.input.first().filter(|s| !s.is_empty())\n                    && let Some(&dt) = types.get(in_name.as_str())\n                {\n                    for out in &node.output {\n                        if !out.is_empty() {\n                            types.insert(out.clone(), dt);\n                        }\n                    }\n                }\n            }\n            _ => {}\n        }\n    }\n    types\n}\n\nfn extract_dslice_archive(archive: &Path, dest: &Path) -> Result<()> {\n    let tmp_dir = dest.with_file_name(format!(\n        \".{}.extracting.{}\",\n        dest.file_name().unwrap_or_default().to_string_lossy(),\n        std::process::id()\n    ));\n    std::fs::create_dir_all(&tmp_dir).map_err(|e| DsperseError::io(e, &tmp_dir))?;\n    let file = std::fs::File::open(archive).map_err(|e| DsperseError::io(e, archive))?;\n    let mut zip = zip::ZipArchive::new(file).map_err(|e| {\n        DsperseError::Slicer(format!(\"reading dslice archive {}: {e}\", archive.display()))\n    })?;\n    if let Err(e) = zip.extract(&tmp_dir) {\n        std::fs::remove_dir_all(&tmp_dir).ok();\n        return Err(DsperseError::Slicer(format!(\n            \"extracting {} to {}: {e}\",\n            archive.display(),\n            tmp_dir.display()\n        )));\n    }\n    if let Err(e) = std::fs::rename(&tmp_dir, dest) {\n        std::fs::remove_dir_all(&tmp_dir).ok();\n        if dest.exists() {\n            return Ok(());\n        }\n        return Err(DsperseError::Slicer(format!(\n            \"renaming {} to {}: {e}\",\n            tmp_dir.display(),\n            dest.display()\n        )));\n    }\n    tracing::debug!(archive = %archive.display(), dest = %dest.display(), \"extracted dslice archive\");\n    Ok(())\n}\n\npub fn cleanup_extracted_slice(slices_dir: &Path, slice_id: &str) {\n    let extract_dir = 
slices_dir.join(slice_id);\n    if std::fs::remove_dir_all(&extract_dir).is_err() && extract_dir.exists() {\n        tracing::warn!(dir = %extract_dir.display(), \"failed to remove extracted slice dir\");\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/mod.rs",
    "content": "pub mod analyzer;\npub mod autotiler;\npub mod combiner;\npub(crate) mod layernorm_fuse;\npub mod materializer;\npub(crate) mod onnx_fold;\npub mod onnx_proto;\npub(crate) mod onnx_shapes;\npub mod onnx_slicer;\npub(crate) mod self_div_rewrite;\npub(crate) mod trace;\n\npub use onnx_slicer::slice_model;\n\npub(crate) const UNARY_ACTIVATIONS: &[&str] = &[\n    \"Relu\",\n    \"LeakyRelu\",\n    \"PRelu\",\n    \"Sigmoid\",\n    \"Tanh\",\n    \"Clip\",\n    \"Neg\",\n    \"Abs\",\n    \"Sqrt\",\n    \"Exp\",\n    \"Log\",\n    \"Sin\",\n    \"Cos\",\n    \"Erf\",\n];\n\npub(crate) const UNARY_STRUCTURAL: &[&str] = &[\"Cast\", \"Not\", \"Identity\", \"Dropout\"];\n\npub(crate) const BINARY_ARITHMETIC: &[&str] = &[\"Add\", \"Sub\", \"Mul\", \"Div\", \"Pow\", \"Max\", \"Min\"];\n\npub(crate) const NORMALIZATION_OPS: &[&str] =\n    &[\"BatchNormalization\", \"Softmax\", \"LayerNormalization\"];\n\npub(crate) const LAYOUT_OPS: &[&str] = &[\n    \"Reshape\",\n    \"Transpose\",\n    \"Flatten\",\n    \"Squeeze\",\n    \"Unsqueeze\",\n    \"Gather\",\n];\n\npub(crate) const CONTROL_FLOW_OPS: &[&str] = &[\"Loop\", \"If\", \"Scan\"];\n\npub(crate) fn is_control_flow(op: &str) -> bool {\n    CONTROL_FLOW_OPS.contains(&op)\n}\n\npub(crate) fn collect_subgraph_outer_refs(\n    node: &onnx_proto::NodeProto,\n    graph: &onnx_proto::GraphProto,\n) -> Vec<String> {\n    let mut outer_refs = Vec::new();\n    for attr in &node.attribute {\n        let subgraphs: Vec<&onnx_proto::GraphProto> =\n            attr.g.iter().chain(attr.graphs.iter()).collect();\n        for sg in subgraphs {\n            collect_outer_refs_recursive(sg, graph, &mut outer_refs);\n        }\n    }\n    outer_refs.sort();\n    outer_refs.dedup();\n    outer_refs\n}\n\nfn collect_outer_refs_recursive(\n    subgraph: &onnx_proto::GraphProto,\n    outer_graph: &onnx_proto::GraphProto,\n    outer_refs: &mut Vec<String>,\n) {\n    let local_names: std::collections::HashSet<String> = subgraph\n     
   .input\n        .iter()\n        .map(|vi| vi.name.clone())\n        .chain(subgraph.initializer.iter().map(|i| i.name.clone()))\n        .chain(subgraph.node.iter().flat_map(|n| n.output.iter().cloned()))\n        .collect();\n\n    let outer_names: std::collections::HashSet<&str> = outer_graph\n        .input\n        .iter()\n        .map(|vi| vi.name.as_str())\n        .chain(outer_graph.initializer.iter().map(|i| i.name.as_str()))\n        .chain(\n            outer_graph\n                .node\n                .iter()\n                .flat_map(|n| n.output.iter().map(|s| s.as_str())),\n        )\n        .chain(outer_graph.value_info.iter().map(|vi| vi.name.as_str()))\n        .collect();\n\n    for sg_node in &subgraph.node {\n        for inp in &sg_node.input {\n            if !inp.is_empty() && !local_names.contains(inp) && outer_names.contains(inp.as_str()) {\n                outer_refs.push(inp.clone());\n            }\n        }\n        for attr in &sg_node.attribute {\n            let nested: Vec<&onnx_proto::GraphProto> =\n                attr.g.iter().chain(attr.graphs.iter()).collect();\n            for nested_sg in nested {\n                collect_outer_refs_recursive(nested_sg, outer_graph, outer_refs);\n            }\n        }\n    }\n}\n\npub(crate) fn is_shape_preserving(op: &str) -> bool {\n    UNARY_ACTIVATIONS.contains(&op)\n        || UNARY_STRUCTURAL.contains(&op)\n        || NORMALIZATION_OPS.contains(&op)\n}\n\n/// Ops the slicer may absorb into an adjacent activation slice\n/// without creating a new compile boundary.  
This is a superset of\n/// `is_shape_preserving`: it additionally covers the layout ops\n/// (Reshape / Transpose / Flatten / Squeeze / Unsqueeze / Gather)\n/// which CHANGE the tensor shape but do not introduce heavy\n/// compute -- grouping them with the producer keeps transformer\n/// reshape-transpose chains from shattering into N one-op slices.\n///\n/// Do NOT reuse this for shape-fallback decisions:\n/// `is_shape_preserving` is consumed by the trace / materializer\n/// to assume `output_shape == input_shape`, which is FALSE for\n/// every op in LAYOUT_OPS by definition.  Keep that predicate\n/// strict and route slicer-grouping checks through this function\n/// instead.\npub(crate) fn is_slice_passthrough(op: &str) -> bool {\n    is_shape_preserving(op) || LAYOUT_OPS.contains(&op)\n}\n\npub(crate) fn is_elementwise(op: &str) -> bool {\n    UNARY_ACTIVATIONS.contains(&op) || BINARY_ARITHMETIC.contains(&op)\n}\n\npub(crate) fn is_binary_arithmetic(op: &str) -> bool {\n    BINARY_ARITHMETIC.contains(&op)\n}\n\npub(crate) fn build_segment_ranges(\n    slice_points: &[usize],\n    total_nodes: Option<usize>,\n) -> Vec<(usize, usize)> {\n    let mut points = slice_points.to_vec();\n    if let Some(total) = total_nodes\n        && !points.contains(&total)\n    {\n        points.push(total);\n    }\n    points.sort();\n    points.dedup();\n\n    let mut ranges = Vec::new();\n    for i in 0..points.len() {\n        let start = if i > 0 { points[i - 1] } else { 0 };\n        let end = points[i];\n        if start < end {\n            ranges.push((start, end));\n        }\n    }\n    ranges\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/onnx_fold.rs",
    "content": "use std::collections::{HashMap, HashSet};\n\nuse super::onnx_proto::{\n    GraphProto, ModelProto, NodeProto, TensorProto, tensor_to_f32, tensor_to_i64,\n};\n\npub fn fold_constant_nodes(model: &mut ModelProto) -> HashSet<String> {\n    let graph = match model.graph.as_mut() {\n        Some(g) => g,\n        None => return HashSet::new(),\n    };\n\n    let mut folded_tensors: Vec<TensorProto> = Vec::new();\n    let mut folded_names: HashSet<String> = HashSet::new();\n\n    for node in &graph.node {\n        if node.op_type != \"Constant\" {\n            continue;\n        }\n        let out_name = match node.output.first() {\n            Some(n) if !n.is_empty() => n,\n            _ => continue,\n        };\n        let tensor = match node.attribute.iter().find(|a| a.name == \"value\") {\n            Some(a) => match a.t.as_ref() {\n                Some(t) => t,\n                None => continue,\n            },\n            None => continue,\n        };\n        let mut t = tensor.clone();\n        t.name = out_name.clone();\n        folded_tensors.push(t);\n        folded_names.insert(out_name.clone());\n    }\n\n    if folded_names.is_empty() {\n        return folded_names;\n    }\n\n    graph\n        .node\n        .retain(|n| n.op_type != \"Constant\" || !n.output.iter().any(|o| folded_names.contains(o)));\n\n    let count = folded_tensors.len();\n    graph.initializer.extend(folded_tensors);\n\n    tracing::info!(count, \"folded Constant ops into initializers\");\n\n    let propagated_names = propagate_constants(graph);\n    if !propagated_names.is_empty() {\n        tracing::info!(\n            propagated = propagated_names.len(),\n            \"propagated constants after Constant-node folding\"\n        );\n    }\n    folded_names.extend(propagated_names);\n\n    // Graph simplification runs before Conv+BN fusion so that any\n    // Identity chain sitting between a Conv and a BatchNormalization\n    // collapses first, exposing a 
contiguous Conv -> BN pattern to\n    // the fusion pass.\n    let identity_count = remove_identity_nodes(graph);\n    if identity_count > 0 {\n        tracing::info!(identity_count, \"removed Identity nodes\");\n    }\n\n    let dead_count = eliminate_dead_nodes(graph);\n    if dead_count > 0 {\n        tracing::info!(dead_count, \"eliminated dead nodes\");\n    }\n\n    let fused = fuse_conv_batchnorm(graph);\n    if fused > 0 {\n        tracing::info!(fused, \"fused Conv+BatchNormalization pairs\");\n    }\n\n    folded_names\n}\n\npub fn remove_identity_nodes(graph: &mut GraphProto) -> usize {\n    let identity_map: HashMap<String, String> = graph\n        .node\n        .iter()\n        .filter(|n| n.op_type == \"Identity\" && n.input.len() == 1 && n.output.len() == 1)\n        .filter(|n| !n.input[0].is_empty() && !n.output[0].is_empty())\n        .map(|n| (n.output[0].clone(), n.input[0].clone()))\n        .collect();\n\n    if identity_map.is_empty() {\n        return 0;\n    }\n\n    fn resolve(name: &str, map: &HashMap<String, String>) -> String {\n        let mut current = name;\n        let mut visited = HashSet::new();\n        while let Some(target) = map.get(current) {\n            if !visited.insert(current) {\n                break;\n            }\n            current = target;\n        }\n        current.to_string()\n    }\n\n    let output_names: HashSet<String> = graph.output.iter().map(|o| o.name.clone()).collect();\n\n    // Only rewire consumers whose Identity output is NOT an exported\n    // graph output.  Exported names are the model's public interface\n    // and must survive as-is; we preserve those Identity nodes\n    // instead of renaming the graph output.  
Rewriting graph.output\n    // in place would silently change the model's API and let DCE\n    // below remove the Identity that produces the exported tensor.\n    let drop_map: HashMap<String, String> = identity_map\n        .iter()\n        .filter(|(out, _)| !output_names.contains(out.as_str()))\n        .map(|(out, inp)| (out.clone(), inp.clone()))\n        .collect();\n\n    for node in &mut graph.node {\n        // Skip the node that produced this drop-map entry so we\n        // don't rewrite its own input to its own output.  Guard\n        // the output-slot access: the drop_map construction only\n        // accepts len-1 Identity nodes, but a malformed Identity\n        // with zero outputs could still appear in graph.node and\n        // must not trip an index panic here.\n        let is_dropped_identity = node.op_type == \"Identity\"\n            && node\n                .output\n                .first()\n                .is_some_and(|o| drop_map.contains_key(o.as_str()));\n        if is_dropped_identity {\n            continue;\n        }\n        for inp in &mut node.input {\n            if drop_map.contains_key(inp.as_str()) {\n                *inp = resolve(inp, &drop_map);\n            }\n        }\n    }\n\n    let count = drop_map.len();\n    graph.node.retain(|n| {\n        !(n.op_type == \"Identity\" && n.output.len() == 1 && drop_map.contains_key(&n.output[0]))\n    });\n    count\n}\n\npub fn eliminate_dead_nodes(graph: &mut GraphProto) -> usize {\n    let output_names: HashSet<String> = graph.output.iter().map(|o| o.name.clone()).collect();\n\n    let mut consumed: HashSet<String> = output_names;\n    let mut changed = true;\n    while changed {\n        changed = false;\n        for node in &graph.node {\n            let produces_consumed = node.output.iter().any(|o| consumed.contains(o));\n            if produces_consumed {\n                for inp in &node.input {\n                    if !inp.is_empty() && consumed.insert(inp.clone()) {\n  
                      changed = true;\n                    }\n                }\n            }\n        }\n    }\n\n    let before = graph.node.len();\n    graph\n        .node\n        .retain(|n| n.output.iter().any(|o| consumed.contains(o)));\n    let removed = before - graph.node.len();\n\n    if removed > 0 {\n        graph.initializer.retain(|i| consumed.contains(&i.name));\n        graph.value_info.retain(|vi| consumed.contains(&vi.name));\n    }\n\n    removed\n}\n\npub fn propagate_constants_with_shapes(\n    graph: &mut GraphProto,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n) -> usize {\n    for node in &graph.node {\n        if node.op_type == \"Shape\"\n            && let Some(inp_name) = node.input.first()\n            && let Some(full_shape) = traced_shapes.get(inp_name)\n            && let Some(out_name) = node.output.first()\n            && !out_name.is_empty()\n            && !graph.initializer.iter().any(|i| i.name == *out_name)\n        {\n            let ndim = full_shape.len() as i64;\n            let start_attr = node\n                .attribute\n                .iter()\n                .find(|a| a.name == \"start\")\n                .map(|a| a.i)\n                .unwrap_or(0);\n            let end_attr = node\n                .attribute\n                .iter()\n                .find(|a| a.name == \"end\")\n                .map(|a| a.i)\n                .unwrap_or(ndim);\n            let start = if start_attr < 0 {\n                (ndim + start_attr).max(0) as usize\n            } else {\n                (start_attr as usize).min(full_shape.len())\n            };\n            let end = if end_attr < 0 {\n                (ndim + end_attr).max(0) as usize\n            } else {\n                (end_attr as usize).min(full_shape.len())\n            };\n            let sliced: Vec<i64> = if start < end {\n                full_shape[start..end].to_vec()\n            } else {\n                vec![]\n            };\n            
graph.initializer.push(TensorProto {\n                name: out_name.clone(),\n                data_type: TensorProto::INT64,\n                dims: vec![sliced.len() as i64],\n                int64_data: sliced,\n                ..Default::default()\n            });\n        }\n    }\n    let init_names: HashSet<String> = graph.initializer.iter().map(|i| i.name.clone()).collect();\n    graph\n        .node\n        .retain(|n| n.op_type != \"Shape\" || !n.output.iter().any(|o| init_names.contains(o)));\n    let folded = propagate_constants(graph);\n    folded.len()\n}\n\npub(crate) fn propagate_constants(graph: &mut GraphProto) -> HashSet<String> {\n    let mut constants: HashMap<String, TensorProto> = graph\n        .initializer\n        .iter()\n        .map(|t| (t.name.clone(), t.clone()))\n        .collect();\n\n    let mut folded_node_indices: HashSet<usize> = HashSet::new();\n\n    loop {\n        let mut progress = false;\n        for (idx, node) in graph.node.iter().enumerate() {\n            if folded_node_indices.contains(&idx) {\n                continue;\n            }\n            let inputs: Vec<&str> = node\n                .input\n                .iter()\n                .filter(|s| !s.is_empty())\n                .map(String::as_str)\n                .collect();\n            if inputs.is_empty() {\n                continue;\n            }\n            if !inputs.iter().all(|name| constants.contains_key(*name)) {\n                continue;\n            }\n            let input_tensors: Vec<&TensorProto> = inputs.iter().map(|n| &constants[*n]).collect();\n            if let Some(outputs) = eval_const_node(node, &input_tensors) {\n                for (out_name, tensor) in outputs {\n                    constants.insert(out_name, tensor);\n                }\n                folded_node_indices.insert(idx);\n                progress = true;\n            }\n        }\n        if !progress {\n            break;\n        }\n    }\n\n    if 
folded_node_indices.is_empty() {\n        return HashSet::new();\n    }\n\n    let mut new_init_names: HashSet<String> = HashSet::new();\n    for idx in &folded_node_indices {\n        for out in &graph.node[*idx].output {\n            if !out.is_empty() && constants.contains_key(out) {\n                new_init_names.insert(out.clone());\n            }\n        }\n    }\n\n    let mut consumed_by_remaining: HashSet<String> = graph\n        .node\n        .iter()\n        .enumerate()\n        .filter(|(i, _)| !folded_node_indices.contains(i))\n        .flat_map(|(_, n)| n.input.iter().cloned())\n        .collect();\n    for node in &graph.node {\n        if super::is_control_flow(&node.op_type) {\n            let outer_refs = super::collect_subgraph_outer_refs(node, graph);\n            consumed_by_remaining.extend(outer_refs);\n        }\n    }\n    let output_names: HashSet<String> = graph.output.iter().map(|o| o.name.clone()).collect();\n\n    for name in &new_init_names {\n        if (consumed_by_remaining.contains(name) || output_names.contains(name))\n            && let Some(t) = constants.get(name)\n            && !graph.initializer.iter().any(|i| i.name == *name)\n        {\n            graph.initializer.push(t.clone());\n        }\n    }\n\n    let removed_outputs: HashSet<String> = folded_node_indices\n        .iter()\n        .flat_map(|idx| graph.node[*idx].output.iter().cloned())\n        .collect();\n    graph\n        .input\n        .retain(|vi| !removed_outputs.contains(&vi.name) || output_names.contains(&vi.name));\n\n    let count = folded_node_indices.len();\n    let mut kept = Vec::with_capacity(graph.node.len() - count);\n    for (idx, node) in graph.node.drain(..).enumerate() {\n        if !folded_node_indices.contains(&idx) {\n            kept.push(node);\n        }\n    }\n    graph.node = kept;\n\n    tracing::info!(count, \"propagated constant subgraphs into initializers\");\n    new_init_names\n}\n\nfn eval_const_node(\n    node: 
&NodeProto,\n    inputs: &[&TensorProto],\n) -> Option<Vec<(String, TensorProto)>> {\n    let out_name = node.output.first()?.clone();\n    if out_name.is_empty() {\n        return None;\n    }\n    match node.op_type.as_str() {\n        \"Identity\" => {\n            let mut t = inputs[0].clone();\n            t.name = out_name.clone();\n            Some(vec![(out_name, t)])\n        }\n        \"Cast\" => eval_cast(node, inputs[0], &out_name),\n        \"Sqrt\" => eval_unary_f32(inputs[0], &out_name, f32::sqrt),\n        \"Neg\" => eval_unary_f32(inputs[0], &out_name, |x| -x),\n        \"Abs\" => eval_unary_f32(inputs[0], &out_name, f32::abs),\n        \"Exp\" => eval_unary_f32(inputs[0], &out_name, f32::exp),\n        \"Log\" => eval_unary_f32(inputs[0], &out_name, f32::ln),\n        \"Ceil\" => eval_unary_f32(inputs[0], &out_name, f32::ceil),\n        \"Floor\" => eval_unary_f32(inputs[0], &out_name, f32::floor),\n        \"Reciprocal\" => eval_unary_f32(inputs[0], &out_name, |x| 1.0 / x),\n        \"Relu\" => eval_unary_f32(inputs[0], &out_name, |x| x.max(0.0)),\n        \"Sigmoid\" => eval_unary_f32(inputs[0], &out_name, |x| 1.0 / (1.0 + (-x).exp())),\n        \"Tanh\" => eval_unary_f32(inputs[0], &out_name, f32::tanh),\n        \"Add\" => eval_binary_f32(inputs, &out_name, |a, b| a + b),\n        \"Sub\" => eval_binary_f32(inputs, &out_name, |a, b| a - b),\n        \"Mul\" => eval_binary_f32(inputs, &out_name, |a, b| a * b),\n        \"Div\" => eval_binary_f32(inputs, &out_name, |a, b| a / b),\n        \"Pow\" => eval_binary_f32(inputs, &out_name, f32::powf),\n        \"Reshape\" => eval_reshape(node, inputs, &out_name),\n        \"Squeeze\" => eval_squeeze(node, inputs, &out_name),\n        \"Unsqueeze\" => eval_unsqueeze(node, inputs, &out_name),\n        \"Shape\" => eval_shape(node, inputs[0], &out_name),\n        \"Gather\" if inputs.len() >= 2 => eval_gather(node, inputs, &out_name),\n        \"Slice\" if inputs.len() >= 3 => eval_slice(inputs, 
&out_name),\n        \"Concat\" => eval_concat(node, inputs, &out_name),\n        \"ConstantOfShape\" => eval_constant_of_shape(node, inputs[0], &out_name),\n        \"Where\" if inputs.len() == 3 => eval_where(inputs, &out_name),\n        \"Range\" if inputs.len() == 3 => eval_range(inputs, &out_name),\n        \"Equal\" => eval_cmp(inputs, &out_name, |a, b| a == b, |a, b| a == b),\n        \"Less\" => eval_cmp(inputs, &out_name, |a, b| a < b, |a, b| a < b),\n        \"Greater\" => eval_cmp(inputs, &out_name, |a, b| a > b, |a, b| a > b),\n        \"Not\" => eval_not(inputs[0], &out_name),\n        \"And\" => eval_logical(inputs, &out_name, |a, b| a & b),\n        \"Or\" => eval_logical(inputs, &out_name, |a, b| a | b),\n        \"Transpose\" => eval_transpose(node, inputs[0], &out_name),\n        \"ReduceMean\" => eval_reduce(node, inputs, &out_name, ReduceOp::Mean),\n        \"ReduceSum\" => eval_reduce(node, inputs, &out_name, ReduceOp::Sum),\n        \"ReduceMax\" => eval_reduce(node, inputs, &out_name, ReduceOp::Max),\n        \"ReduceMin\" => eval_reduce(node, inputs, &out_name, ReduceOp::Min),\n        \"Resize\" => eval_resize(node, inputs, &out_name),\n        \"Expand\" if inputs.len() == 2 => eval_expand(inputs, &out_name),\n        \"Tile\" if inputs.len() == 2 => eval_tile(inputs, &out_name),\n        \"ScatterND\" if inputs.len() == 3 => eval_scatter_nd(inputs, &out_name),\n        \"Split\" => eval_split(node, inputs, &node.output),\n        _ => None,\n    }\n}\n\nfn eval_expand(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {\n    let data = inputs[0];\n    let shape = tensor_to_i64(inputs[1]);\n    if shape.is_empty() {\n        return None;\n    }\n    let out_dims = broadcast_shape(&data.dims, &shape)?;\n    let total = broadcast_total(&out_dims)?;\n    if data.data_type == TensorProto::INT64 {\n        let v = tensor_to_i64(data);\n        if v.is_empty() {\n            return None;\n        }\n        let mut 
result = Vec::with_capacity(total);\n        for i in 0..total {\n            let di = broadcast_index(i, &out_dims, &data.dims);\n            result.push(v[di]);\n        }\n        let t = TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: out_dims,\n            int64_data: result,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    let v = tensor_to_f32(data);\n    if v.is_empty() {\n        return None;\n    }\n    let mut result = Vec::with_capacity(total);\n    for i in 0..total {\n        let di = broadcast_index(i, &out_dims, &data.dims);\n        result.push(v[di]);\n    }\n    let t = make_f32_tensor(out_name, &out_dims, &result, data.data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_tile(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {\n    let data = inputs[0];\n    let repeats = tensor_to_i64(inputs[1]);\n    if repeats.is_empty() || repeats.len() != data.dims.len() {\n        return None;\n    }\n    let rank = data.dims.len();\n    let out_dims: Vec<i64> = data\n        .dims\n        .iter()\n        .zip(&repeats)\n        .map(|(&d, &r)| d * r)\n        .collect();\n    let total = broadcast_total(&out_dims)?;\n\n    let in_strides: Vec<usize> = {\n        let mut s = vec![1usize; rank];\n        for i in (0..rank.saturating_sub(1)).rev() {\n            s[i] = s[i + 1] * data.dims[i + 1] as usize;\n        }\n        s\n    };\n    let out_strides: Vec<usize> = {\n        let mut s = vec![1usize; rank];\n        for i in (0..rank.saturating_sub(1)).rev() {\n            s[i] = s[i + 1] * out_dims[i + 1] as usize;\n        }\n        s\n    };\n\n    if data.data_type == TensorProto::INT64 {\n        let v = tensor_to_i64(data);\n        if v.is_empty() {\n            return None;\n        }\n        let mut result = vec![0i64; total];\n        for (o, out_slot) in 
result.iter_mut().enumerate().take(total) {\n            let mut src = 0usize;\n            let mut rem = o;\n            for i in 0..rank {\n                let coord = rem / out_strides[i];\n                rem %= out_strides[i];\n                src += (coord % data.dims[i] as usize) * in_strides[i];\n            }\n            *out_slot = v[src];\n        }\n        let t = TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: out_dims,\n            int64_data: result,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    let v = tensor_to_f32(data);\n    if v.is_empty() {\n        return None;\n    }\n    let mut result = vec![0f32; total];\n    for (o, out_slot) in result.iter_mut().enumerate().take(total) {\n        let mut src = 0usize;\n        let mut rem = o;\n        for i in 0..rank {\n            let coord = rem / out_strides[i];\n            rem %= out_strides[i];\n            src += (coord % data.dims[i] as usize) * in_strides[i];\n        }\n        *out_slot = v[src];\n    }\n    let t = make_f32_tensor(out_name, &out_dims, &result, data.data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_constant_of_shape(\n    node: &NodeProto,\n    shape_t: &TensorProto,\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    let dims = tensor_to_i64(shape_t);\n    if dims.is_empty() {\n        return None;\n    }\n    let total = broadcast_total(&dims)?;\n    let (dtype, f_val, i_val) = match node.attribute.iter().find(|a| a.name == \"value\") {\n        Some(a) => match a.t.as_ref() {\n            Some(t) => {\n                let fv = tensor_to_f32(t).first().copied().unwrap_or(0.0);\n                let iv = tensor_to_i64(t).first().copied().unwrap_or(fv as i64);\n                (t.data_type, fv, iv)\n            }\n            None => (TensorProto::FLOAT, 0.0, 0),\n        },\n        None => 
(TensorProto::FLOAT, 0.0, 0),\n    };\n    let t = match dtype {\n        TensorProto::INT64 => TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: dims.clone(),\n            int64_data: vec![i_val; total],\n            ..Default::default()\n        },\n        TensorProto::INT32 => TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT32,\n            dims: dims.clone(),\n            int32_data: vec![i_val as i32; total],\n            ..Default::default()\n        },\n        TensorProto::BOOL => TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::BOOL,\n            dims: dims.clone(),\n            int32_data: vec![(i_val != 0) as i32; total],\n            ..Default::default()\n        },\n        _ => make_f32_tensor(out_name, &dims, &vec![f_val; total], dtype),\n    };\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_where(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {\n    let cond = tensor_to_i64(inputs[0]);\n    let data_type =\n        if inputs[1].data_type == TensorProto::INT64 && inputs[2].data_type == TensorProto::INT64 {\n            TensorProto::INT64\n        } else {\n            TensorProto::FLOAT\n        };\n    if data_type == TensorProto::INT64 {\n        let x = tensor_to_i64(inputs[1]);\n        let y = tensor_to_i64(inputs[2]);\n        if x.is_empty() || y.is_empty() || cond.is_empty() {\n            return None;\n        }\n        let xy_dims = broadcast_shape(&inputs[1].dims, &inputs[2].dims)?;\n        let out_dims = broadcast_shape(&xy_dims, &inputs[0].dims)?;\n        let total = broadcast_total(&out_dims)?;\n        let mut result = Vec::with_capacity(total);\n        for i in 0..total {\n            let ci = broadcast_index(i, &out_dims, &inputs[0].dims);\n            let xi = broadcast_index(i, &out_dims, &inputs[1].dims);\n            let yi = 
broadcast_index(i, &out_dims, &inputs[2].dims);\n            result.push(if cond[ci] != 0 { x[xi] } else { y[yi] });\n        }\n        let t = TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: out_dims,\n            int64_data: result,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    let x = tensor_to_f32(inputs[1]);\n    let y = tensor_to_f32(inputs[2]);\n    if x.is_empty() || y.is_empty() || cond.is_empty() {\n        return None;\n    }\n    let xy_dims = broadcast_shape(&inputs[1].dims, &inputs[2].dims)?;\n    let out_dims = broadcast_shape(&xy_dims, &inputs[0].dims)?;\n    let total = broadcast_total(&out_dims)?;\n    let mut result = Vec::with_capacity(total);\n    for i in 0..total {\n        let ci = broadcast_index(i, &out_dims, &inputs[0].dims);\n        let xi = broadcast_index(i, &out_dims, &inputs[1].dims);\n        let yi = broadcast_index(i, &out_dims, &inputs[2].dims);\n        result.push(if cond[ci] != 0 { x[xi] } else { y[yi] });\n    }\n    let t = make_f32_tensor(out_name, &out_dims, &result, inputs[1].data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_range(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {\n    let is_int = inputs[0].data_type == TensorProto::INT64\n        && inputs[1].data_type == TensorProto::INT64\n        && inputs[2].data_type == TensorProto::INT64;\n    if is_int {\n        let start = tensor_to_i64(inputs[0]).first().copied()?;\n        let limit = tensor_to_i64(inputs[1]).first().copied()?;\n        let delta = tensor_to_i64(inputs[2]).first().copied()?;\n        if delta == 0 {\n            return None;\n        }\n        let producing = (delta > 0 && start < limit) || (delta < 0 && start > limit);\n        let count = if producing {\n            let span = (limit - start) as i128;\n            let d = delta as i128;\n            let c 
= (span + d - d.signum()) / d;\n            usize::try_from(c).ok()?\n        } else {\n            0\n        };\n        if count > MAX_BROADCAST_ELEMENTS {\n            return None;\n        }\n        let mut out = Vec::with_capacity(count);\n        let mut v = start;\n        for _ in 0..count {\n            out.push(v);\n            v = v.checked_add(delta)?;\n        }\n        let t = TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: vec![out.len() as i64],\n            int64_data: out,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    let start = tensor_to_f32(inputs[0]).first().copied()?;\n    let limit = tensor_to_f32(inputs[1]).first().copied()?;\n    let delta = tensor_to_f32(inputs[2]).first().copied()?;\n    if delta == 0.0 || !delta.is_finite() || !start.is_finite() || !limit.is_finite() {\n        return None;\n    }\n    let count = ((limit - start) / delta).ceil();\n    if count <= 0.0 {\n        let dims = vec![0i64];\n        let t = make_f32_tensor(out_name, &dims, &[], inputs[0].data_type);\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    if count as usize > MAX_BROADCAST_ELEMENTS {\n        return None;\n    }\n    let count = count as usize;\n    let mut out = Vec::with_capacity(count);\n    let mut v = start;\n    for _ in 0..count {\n        if (delta > 0.0 && v >= limit) || (delta < 0.0 && v <= limit) {\n            break;\n        }\n        out.push(v);\n        v += delta;\n    }\n    let dims = vec![out.len() as i64];\n    let t = make_f32_tensor(out_name, &dims, &out, inputs[0].data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_cmp(\n    inputs: &[&TensorProto],\n    out_name: &str,\n    f_f32: fn(f32, f32) -> bool,\n    f_i64: fn(i64, i64) -> bool,\n) -> Option<Vec<(String, TensorProto)>> {\n    if inputs.len() < 2 {\n        return None;\n    }\n    let out_dims = 
broadcast_shape(&inputs[0].dims, &inputs[1].dims)?;\n    let total = broadcast_total(&out_dims)?;\n\n    let both_int =\n        inputs[0].data_type == TensorProto::INT64 && inputs[1].data_type == TensorProto::INT64;\n    let mut result = Vec::with_capacity(total);\n    if both_int {\n        let a = tensor_to_i64(inputs[0]);\n        let b = tensor_to_i64(inputs[1]);\n        if a.is_empty() || b.is_empty() {\n            return None;\n        }\n        for i in 0..total {\n            let ai = broadcast_index(i, &out_dims, &inputs[0].dims);\n            let bi = broadcast_index(i, &out_dims, &inputs[1].dims);\n            result.push(f_i64(a[ai], b[bi]) as i32);\n        }\n    } else {\n        let a = tensor_to_f32(inputs[0]);\n        let b = tensor_to_f32(inputs[1]);\n        if a.is_empty() || b.is_empty() {\n            return None;\n        }\n        for i in 0..total {\n            let ai = broadcast_index(i, &out_dims, &inputs[0].dims);\n            let bi = broadcast_index(i, &out_dims, &inputs[1].dims);\n            result.push(f_f32(a[ai], b[bi]) as i32);\n        }\n    }\n    let t = TensorProto {\n        name: out_name.to_string(),\n        data_type: TensorProto::BOOL,\n        dims: out_dims,\n        int32_data: result,\n        ..Default::default()\n    };\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_not(input: &TensorProto, out_name: &str) -> Option<Vec<(String, TensorProto)>> {\n    let vals = tensor_to_i64(input);\n    if vals.is_empty() {\n        return None;\n    }\n    let t = TensorProto {\n        name: out_name.to_string(),\n        data_type: TensorProto::BOOL,\n        dims: input.dims.clone(),\n        int32_data: vals.iter().map(|&v| (v == 0) as i32).collect(),\n        ..Default::default()\n    };\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_logical(\n    inputs: &[&TensorProto],\n    out_name: &str,\n    f: fn(i32, i32) -> i32,\n) -> Option<Vec<(String, TensorProto)>> {\n    if inputs.len() < 2 {\n     
   return None;\n    }\n    let a = tensor_to_i64(inputs[0]);\n    let b = tensor_to_i64(inputs[1]);\n    if a.is_empty() || b.is_empty() {\n        return None;\n    }\n    let out_dims = broadcast_shape(&inputs[0].dims, &inputs[1].dims)?;\n    let total = broadcast_total(&out_dims)?;\n    let mut result = Vec::with_capacity(total);\n    for i in 0..total {\n        let ai = broadcast_index(i, &out_dims, &inputs[0].dims);\n        let bi = broadcast_index(i, &out_dims, &inputs[1].dims);\n        result.push(f((a[ai] != 0) as i32, (b[bi] != 0) as i32));\n    }\n    let t = TensorProto {\n        name: out_name.to_string(),\n        data_type: TensorProto::BOOL,\n        dims: out_dims,\n        int32_data: result,\n        ..Default::default()\n    };\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_transpose(\n    node: &NodeProto,\n    input: &TensorProto,\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    let rank = input.dims.len();\n    if rank == 0 {\n        return None;\n    }\n    let perm: Vec<usize> = match node.attribute.iter().find(|a| a.name == \"perm\") {\n        Some(attr) => {\n            if attr.ints.len() != rank {\n                return None;\n            }\n            let mut out = Vec::with_capacity(rank);\n            let mut seen = vec![false; rank];\n            for &raw in &attr.ints {\n                if raw < 0 || (raw as usize) >= rank {\n                    return None;\n                }\n                let p = raw as usize;\n                if seen[p] {\n                    return None;\n                }\n                seen[p] = true;\n                out.push(p);\n            }\n            out\n        }\n        None => (0..rank).rev().collect(),\n    };\n    let out_dims: Vec<i64> = perm.iter().map(|&p| input.dims[p]).collect();\n    let total = broadcast_total(&out_dims)?;\n\n    let src_strides = {\n        let mut s = vec![1i64; rank];\n        for i in (0..rank.saturating_sub(1)).rev() {\n    
        s[i] = s[i + 1] * input.dims[i + 1];\n        }\n        s\n    };\n    let out_strides = {\n        let mut s = vec![1i64; rank];\n        for i in (0..rank.saturating_sub(1)).rev() {\n            s[i] = s[i + 1] * out_dims[i + 1];\n        }\n        s\n    };\n\n    let permute_index = |out_linear: usize| -> usize {\n        let mut src = 0i64;\n        let mut rem = out_linear as i64;\n        for i in 0..rank {\n            let coord = rem / out_strides[i];\n            rem %= out_strides[i];\n            src += coord * src_strides[perm[i]];\n        }\n        src as usize\n    };\n\n    if input.data_type == TensorProto::INT64 {\n        let vals = tensor_to_i64(input);\n        if vals.is_empty() {\n            return None;\n        }\n        let mut result = Vec::with_capacity(total);\n        for i in 0..total {\n            result.push(vals[permute_index(i)]);\n        }\n        let t = TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: out_dims,\n            int64_data: result,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    let vals = tensor_to_f32(input);\n    if vals.is_empty() {\n        return None;\n    }\n    let mut result = Vec::with_capacity(total);\n    for i in 0..total {\n        result.push(vals[permute_index(i)]);\n    }\n    let t = make_f32_tensor(out_name, &out_dims, &result, input.data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\n#[allow(clippy::too_many_lines)]\nfn eval_resize(\n    node: &NodeProto,\n    inputs: &[&TensorProto],\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    let named: Vec<(&str, Option<&TensorProto>)> = {\n        let mut it = inputs.iter().copied();\n        node.input\n            .iter()\n            .map(|name| {\n                let entry = if name.is_empty() { None } else { it.next() };\n                (name.as_str(), entry)\n           
 })\n            .collect()\n    };\n    let x = named.first().and_then(|(_, t)| *t)?;\n    if x.dims.len() < 2 {\n        return None;\n    }\n    let rank = x.dims.len();\n    let vals = tensor_to_f32(x);\n    if vals.is_empty() {\n        return None;\n    }\n\n    let mode = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"mode\")\n        .map(|a| std::str::from_utf8(&a.s).unwrap_or(\"\").to_string())\n        .unwrap_or_else(|| \"nearest\".to_string());\n    let ctm = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"coordinate_transformation_mode\")\n        .map(|a| std::str::from_utf8(&a.s).unwrap_or(\"\").to_string())\n        .unwrap_or_else(|| \"half_pixel\".to_string());\n    let cubic_a = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"cubic_coeff_a\")\n        .map(|a| a.f)\n        .unwrap_or(-0.75);\n    let exclude_outside = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"exclude_outside\")\n        .map(|a| a.i != 0)\n        .unwrap_or(false);\n    let extrapolation = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"extrapolation_value\")\n        .map(|a| a.f)\n        .unwrap_or(0.0);\n\n    let sizes_opt = named.get(3).and_then(|(_, t)| *t).and_then(|t| {\n        if t.dims.is_empty() || t.dims.iter().all(|&d| d == 0) {\n            None\n        } else {\n            let v = tensor_to_i64(t);\n            if v.len() == rank { Some(v) } else { None }\n        }\n    });\n    let scales_opt = named.get(2).and_then(|(_, t)| *t).and_then(|t| {\n        if t.dims.is_empty() || t.dims.iter().all(|&d| d == 0) {\n            None\n        } else {\n            let v = tensor_to_f32(t);\n            if v.len() == rank { Some(v) } else { None }\n        }\n    });\n\n    let out_dims: Vec<i64> = if let Some(sizes) = sizes_opt {\n        sizes\n    } else if let Some(scales) = scales_opt {\n        x.dims\n            .iter()\n       
     .zip(&scales)\n            .map(|(&d, &s)| (d as f32 * s) as i64)\n            .collect()\n    } else {\n        return None;\n    };\n    let total_out = broadcast_total(&out_dims)?;\n\n    let scales_eff: Vec<f32> = x\n        .dims\n        .iter()\n        .zip(&out_dims)\n        .map(|(&s, &o)| o as f32 / s as f32)\n        .collect();\n\n    let src_stride: Vec<usize> = {\n        let mut s = vec![1usize; rank];\n        for i in (0..rank.saturating_sub(1)).rev() {\n            s[i] = s[i + 1] * x.dims[i + 1] as usize;\n        }\n        s\n    };\n\n    let coord = |out_i: i64, d: usize| -> f32 {\n        let out_d = out_dims[d] as f32;\n        let in_d = x.dims[d] as f32;\n        let s = scales_eff[d];\n        match ctm.as_str() {\n            \"half_pixel\" => (out_i as f32 + 0.5) / s - 0.5,\n            \"pytorch_half_pixel\" => {\n                if out_d > 1.0 {\n                    (out_i as f32 + 0.5) / s - 0.5\n                } else {\n                    0.0\n                }\n            }\n            \"align_corners\" => {\n                if out_d > 1.0 {\n                    out_i as f32 * (in_d - 1.0) / (out_d - 1.0)\n                } else {\n                    0.0\n                }\n            }\n            \"asymmetric\" => out_i as f32 / s,\n            _ => (out_i as f32 + 0.5) / s - 0.5,\n        }\n    };\n\n    let mode_kind = match mode.as_str() {\n        \"cubic\" => ResizeMode::Cubic,\n        \"linear\" => ResizeMode::Linear,\n        \"nearest\" => ResizeMode::Nearest,\n        _ => return None,\n    };\n\n    let mut result = vec![0f32; total_out];\n\n    let dst_stride: Vec<usize> = {\n        let mut s = vec![1usize; rank];\n        for i in (0..rank.saturating_sub(1)).rev() {\n            s[i] = s[i + 1] * out_dims[i + 1] as usize;\n        }\n        s\n    };\n\n    let resize_axes: Vec<usize> = (0..rank).filter(|&d| x.dims[d] != out_dims[d]).collect();\n    if resize_axes.is_empty() {\n        let t = 
make_f32_tensor(out_name, &out_dims, &vals, x.data_type);\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    if !(resize_axes.len() == 2\n        && resize_axes[0] + 1 == resize_axes[1]\n        && resize_axes[1] + 1 == rank)\n    {\n        return None;\n    }\n    let h_axis = resize_axes[0];\n    let w_axis = resize_axes[1];\n\n    let outer_total: usize = x.dims[..h_axis].iter().map(|&d| d as usize).product();\n    let in_h = x.dims[h_axis] as usize;\n    let in_w = x.dims[w_axis] as usize;\n    let out_h = out_dims[h_axis] as usize;\n    let out_w = out_dims[w_axis] as usize;\n\n    for outer in 0..outer_total {\n        let in_plane = outer * in_h * in_w;\n        let out_plane = outer * out_h * out_w;\n        for oy in 0..out_h {\n            let sy = coord(oy as i64, h_axis);\n            for ox in 0..out_w {\n                let sx = coord(ox as i64, w_axis);\n                let v = match mode_kind {\n                    ResizeMode::Nearest => {\n                        let yi = nearest_idx(sy, in_h);\n                        let xi = nearest_idx(sx, in_w);\n                        vals[in_plane + yi * in_w + xi]\n                    }\n                    ResizeMode::Linear => sample_linear_2d(\n                        &vals[in_plane..in_plane + in_h * in_w],\n                        in_h,\n                        in_w,\n                        sy,\n                        sx,\n                        exclude_outside,\n                        extrapolation,\n                    ),\n                    ResizeMode::Cubic => sample_cubic_2d(\n                        &vals[in_plane..in_plane + in_h * in_w],\n                        in_h,\n                        in_w,\n                        sy,\n                        sx,\n                        cubic_a,\n                        exclude_outside,\n                        extrapolation,\n                    ),\n                };\n                result[out_plane + oy * out_w + ox] = v;\n  
          }\n        }\n    }\n    let _ = (src_stride, dst_stride);\n    let t = make_f32_tensor(out_name, &out_dims, &result, x.data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\n#[derive(Clone, Copy)]\nenum ResizeMode {\n    Nearest,\n    Linear,\n    Cubic,\n}\n\nfn nearest_idx(s: f32, dim: usize) -> usize {\n    if s < 0.0 {\n        0\n    } else {\n        let i = s.round() as isize;\n        if i >= dim as isize {\n            dim - 1\n        } else {\n            i as usize\n        }\n    }\n}\n\nfn sample_linear_2d(\n    plane: &[f32],\n    h: usize,\n    w: usize,\n    sy: f32,\n    sx: f32,\n    exclude_outside: bool,\n    extrap: f32,\n) -> f32 {\n    let (y0_in, y0) = clamp_axis(sy.floor() as isize, h);\n    let (y1_in, y1) = clamp_axis(sy.floor() as isize + 1, h);\n    let (x0_in, x0) = clamp_axis(sx.floor() as isize, w);\n    let (x1_in, x1) = clamp_axis(sx.floor() as isize + 1, w);\n    if exclude_outside && (!y0_in && !y1_in || !x0_in && !x1_in) {\n        return extrap;\n    }\n    let dy = sy - sy.floor();\n    let dx = sx - sx.floor();\n    let v00 = plane[y0 * w + x0];\n    let v01 = plane[y0 * w + x1];\n    let v10 = plane[y1 * w + x0];\n    let v11 = plane[y1 * w + x1];\n    let a = v00 * (1.0 - dx) + v01 * dx;\n    let b = v10 * (1.0 - dx) + v11 * dx;\n    a * (1.0 - dy) + b * dy\n}\n\n#[allow(clippy::too_many_arguments)]\nfn sample_cubic_2d(\n    plane: &[f32],\n    h: usize,\n    w: usize,\n    sy: f32,\n    sx: f32,\n    a_coef: f32,\n    exclude_outside: bool,\n    extrap: f32,\n) -> f32 {\n    let fx = sx.floor();\n    let fy = sy.floor();\n    let dx = sx - fx;\n    let dy = sy - fy;\n    let wx = cubic_weights(dx, a_coef);\n    let wy = cubic_weights(dy, a_coef);\n    let mut wx_eff = wx;\n    let mut wy_eff = wy;\n    if exclude_outside {\n        for (i, w_ref) in wx_eff.iter_mut().enumerate() {\n            let xi = fx as isize - 1 + i as isize;\n            if xi < 0 || xi >= w as isize {\n                *w_ref = 
0.0;\n            }\n        }\n        for (i, w_ref) in wy_eff.iter_mut().enumerate() {\n            let yi = fy as isize - 1 + i as isize;\n            if yi < 0 || yi >= h as isize {\n                *w_ref = 0.0;\n            }\n        }\n        let sx_sum: f32 = wx_eff.iter().sum();\n        let sy_sum: f32 = wy_eff.iter().sum();\n        if sx_sum == 0.0 || sy_sum == 0.0 {\n            return extrap;\n        }\n        for w_ref in &mut wx_eff {\n            *w_ref /= sx_sum;\n        }\n        for w_ref in &mut wy_eff {\n            *w_ref /= sy_sum;\n        }\n    }\n    let mut out = 0f32;\n    for (iy, &wyv) in wy_eff.iter().enumerate() {\n        if wyv == 0.0 {\n            continue;\n        }\n        let yi = (fy as isize - 1 + iy as isize).clamp(0, h as isize - 1) as usize;\n        let mut row_sum = 0f32;\n        for (ix, &wxv) in wx_eff.iter().enumerate() {\n            if wxv == 0.0 {\n                continue;\n            }\n            let xi = (fx as isize - 1 + ix as isize).clamp(0, w as isize - 1) as usize;\n            row_sum += plane[yi * w + xi] * wxv;\n        }\n        out += row_sum * wyv;\n    }\n    out\n}\n\nfn cubic_weights(t: f32, a: f32) -> [f32; 4] {\n    let t1 = 1.0 + t;\n    let t2 = t;\n    let t3 = 1.0 - t;\n    let t4 = 2.0 - t;\n    [\n        cubic_kernel(t1, a),\n        cubic_kernel(t2, a),\n        cubic_kernel(t3, a),\n        cubic_kernel(t4, a),\n    ]\n}\n\nfn cubic_kernel(x: f32, a: f32) -> f32 {\n    let ax = x.abs();\n    if ax <= 1.0 {\n        (a + 2.0) * ax.powi(3) - (a + 3.0) * ax.powi(2) + 1.0\n    } else if ax < 2.0 {\n        a * ax.powi(3) - 5.0 * a * ax.powi(2) + 8.0 * a * ax - 4.0 * a\n    } else {\n        0.0\n    }\n}\n\nfn clamp_axis(i: isize, dim: usize) -> (bool, usize) {\n    if i < 0 {\n        (false, 0)\n    } else if i >= dim as isize {\n        (false, dim - 1)\n    } else {\n        (true, i as usize)\n    }\n}\n\n#[derive(Clone, Copy)]\nenum ReduceOp {\n    Sum,\n    Mean,\n    
Max,\n    Min,\n}\n\nfn eval_reduce(\n    node: &NodeProto,\n    inputs: &[&TensorProto],\n    out_name: &str,\n    op: ReduceOp,\n) -> Option<Vec<(String, TensorProto)>> {\n    let input = inputs[0];\n    let rank = input.dims.len();\n    if rank == 0 {\n        return None;\n    }\n    // Reduce* for non-floating-point tensors would lose precision\n    // through the tensor_to_f32 path below; refuse to fold them so\n    // the compiler can emit a proper integer reduction.\n    if !matches!(\n        input.data_type,\n        TensorProto::FLOAT | TensorProto::DOUBLE | TensorProto::FLOAT16\n    ) {\n        return None;\n    }\n    let keepdims = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"keepdims\")\n        .map(|a| a.i != 0)\n        .unwrap_or(true);\n    let axes: Vec<i64> = if inputs.len() >= 2 {\n        tensor_to_i64(inputs[1])\n    } else {\n        node.attribute\n            .iter()\n            .find(|a| a.name == \"axes\")\n            .map(|a| a.ints.clone())\n            .unwrap_or_else(|| (0..rank as i64).collect())\n    };\n    let norm_axes: Vec<usize> = axes\n        .iter()\n        .map(|&a| {\n            if a < 0 {\n                (rank as i64 + a) as usize\n            } else {\n                a as usize\n            }\n        })\n        .collect();\n    for &ax in &norm_axes {\n        if ax >= rank {\n            return None;\n        }\n    }\n    let mut out_dims_full = input.dims.clone();\n    for &ax in &norm_axes {\n        out_dims_full[ax] = 1;\n    }\n    let out_dims: Vec<i64> = if keepdims {\n        out_dims_full.clone()\n    } else {\n        out_dims_full\n            .iter()\n            .enumerate()\n            .filter(|(i, _)| !norm_axes.contains(i))\n            .map(|(_, &d)| d)\n            .collect()\n    };\n    let total_out = broadcast_total(&out_dims_full)?;\n    let total_in = broadcast_total(&input.dims)?;\n    let vals = tensor_to_f32(input);\n    if vals.is_empty() {\n        
return None;\n    }\n\n    let reduced_count: i64 = norm_axes.iter().map(|&a| input.dims[a]).product();\n\n    let mut accum = vec![\n        match op {\n            ReduceOp::Sum | ReduceOp::Mean => 0.0f32,\n            ReduceOp::Max => f32::NEG_INFINITY,\n            ReduceOp::Min => f32::INFINITY,\n        };\n        total_out\n    ];\n\n    for (in_idx, &v) in vals.iter().enumerate().take(total_in) {\n        let mut rem = in_idx as i64;\n        let mut out_idx = 0i64;\n        let mut out_stride = 1i64;\n        for i in (0..rank).rev() {\n            let dim_i = input.dims[i];\n            let coord = rem % dim_i;\n            rem /= dim_i;\n            let coord_out = if norm_axes.contains(&i) { 0 } else { coord };\n            out_idx += coord_out * out_stride;\n            out_stride *= out_dims_full[i];\n        }\n        let o = out_idx as usize;\n        accum[o] = match op {\n            ReduceOp::Sum | ReduceOp::Mean => accum[o] + v,\n            ReduceOp::Max => accum[o].max(v),\n            ReduceOp::Min => accum[o].min(v),\n        };\n    }\n    if matches!(op, ReduceOp::Mean) && reduced_count > 0 {\n        for a in &mut accum {\n            *a /= reduced_count as f32;\n        }\n    }\n    let t = make_f32_tensor(out_name, &out_dims, &accum, input.data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_cast(\n    node: &NodeProto,\n    input: &TensorProto,\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    let target_type = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"to\")\n        .map(|a| a.i as i32)?;\n    match target_type {\n        TensorProto::INT64 => {\n            let vals = tensor_to_f32(input);\n            if vals.is_empty() {\n                return None;\n            }\n            let t = TensorProto {\n                name: out_name.to_string(),\n                data_type: TensorProto::INT64,\n                dims: input.dims.clone(),\n                int64_data: 
vals.iter().map(|&v| v as i64).collect(),\n                ..Default::default()\n            };\n            Some(vec![(out_name.to_string(), t)])\n        }\n        TensorProto::INT32 => {\n            let vals = tensor_to_f32(input);\n            if vals.is_empty() {\n                return None;\n            }\n            let t = TensorProto {\n                name: out_name.to_string(),\n                data_type: TensorProto::INT32,\n                dims: input.dims.clone(),\n                int32_data: vals.iter().map(|&v| v as i32).collect(),\n                ..Default::default()\n            };\n            Some(vec![(out_name.to_string(), t)])\n        }\n        TensorProto::FLOAT => {\n            let vals = tensor_to_f32(input);\n            if vals.is_empty() {\n                return None;\n            }\n            let t = TensorProto {\n                name: out_name.to_string(),\n                data_type: TensorProto::FLOAT,\n                dims: input.dims.clone(),\n                float_data: vals,\n                ..Default::default()\n            };\n            Some(vec![(out_name.to_string(), t)])\n        }\n        TensorProto::DOUBLE => {\n            let vals = tensor_to_f32(input);\n            if vals.is_empty() {\n                return None;\n            }\n            let t = TensorProto {\n                name: out_name.to_string(),\n                data_type: TensorProto::DOUBLE,\n                dims: input.dims.clone(),\n                double_data: vals.iter().map(|&v| v as f64).collect(),\n                ..Default::default()\n            };\n            Some(vec![(out_name.to_string(), t)])\n        }\n        TensorProto::BOOL => {\n            let vals = tensor_to_f32(input);\n            if vals.is_empty() {\n                return None;\n            }\n            let t = TensorProto {\n                name: out_name.to_string(),\n                data_type: TensorProto::BOOL,\n                dims: 
input.dims.clone(),\n                int32_data: vals.iter().map(|&v| (v != 0.0) as i32).collect(),\n                ..Default::default()\n            };\n            Some(vec![(out_name.to_string(), t)])\n        }\n        _ => None,\n    }\n}\n\nfn eval_unary_f32(\n    input: &TensorProto,\n    out_name: &str,\n    f: fn(f32) -> f32,\n) -> Option<Vec<(String, TensorProto)>> {\n    let vals: Vec<f32> = tensor_to_f32(input).into_iter().map(f).collect();\n    if vals.is_empty() {\n        return None;\n    }\n    let out_type = input.data_type;\n    let t = make_f32_tensor(out_name, &input.dims, &vals, out_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_binary_f32(\n    inputs: &[&TensorProto],\n    out_name: &str,\n    f: fn(f32, f32) -> f32,\n) -> Option<Vec<(String, TensorProto)>> {\n    if inputs.len() < 2 {\n        return None;\n    }\n    let both_int64 =\n        inputs[0].data_type == TensorProto::INT64 && inputs[1].data_type == TensorProto::INT64;\n    if both_int64 {\n        let a = tensor_to_i64(inputs[0]);\n        let b = tensor_to_i64(inputs[1]);\n        if a.is_empty() || b.is_empty() {\n            return None;\n        }\n        let (result, dims) =\n            broadcast_binary_i64(&a, &inputs[0].dims, &b, &inputs[1].dims, |x, y| {\n                f(x as f32, y as f32) as i64\n            })?;\n        let t = TensorProto {\n            name: out_name.to_string(),\n            dims,\n            data_type: TensorProto::INT64,\n            int64_data: result,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    let a = tensor_to_f32(inputs[0]);\n    let b = tensor_to_f32(inputs[1]);\n    if a.is_empty() || b.is_empty() {\n        return None;\n    }\n    let (result, dims) = broadcast_binary(&a, &inputs[0].dims, &b, &inputs[1].dims, f)?;\n    let t = make_f32_tensor(out_name, &dims, &result, TensorProto::FLOAT);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn 
broadcast_shape(a_dims: &[i64], b_dims: &[i64]) -> Option<Vec<i64>> {\n    let rank = a_dims.len().max(b_dims.len());\n    let mut out = Vec::with_capacity(rank);\n    for i in 0..rank {\n        let da = if i < rank - a_dims.len() {\n            1\n        } else {\n            a_dims[i - (rank - a_dims.len())]\n        };\n        let db = if i < rank - b_dims.len() {\n            1\n        } else {\n            b_dims[i - (rank - b_dims.len())]\n        };\n        if da == db {\n            out.push(da);\n        } else if da == 1 {\n            out.push(db);\n        } else if db == 1 {\n            out.push(da);\n        } else {\n            return None;\n        }\n    }\n    Some(out)\n}\n\nfn broadcast_index(out_idx: usize, out_dims: &[i64], src_dims: &[i64]) -> usize {\n    let rank = out_dims.len();\n    let src_rank = src_dims.len();\n    let mut idx = 0;\n    let mut stride = 1;\n    for i in (0..src_rank).rev() {\n        let out_i = rank - src_rank + i;\n        let coord = (out_idx / out_dims[out_i + 1..].iter().product::<i64>().max(1) as usize)\n            % out_dims[out_i] as usize;\n        let src_coord = if src_dims[i] == 1 { 0 } else { coord };\n        idx += src_coord * stride;\n        stride *= src_dims[i] as usize;\n    }\n    idx\n}\n\nconst MAX_BROADCAST_ELEMENTS: usize = 100_000_000;\n\nfn broadcast_total(out_dims: &[i64]) -> Option<usize> {\n    let mut total: usize = 1;\n    for &d in out_dims {\n        let d = usize::try_from(d).ok()?;\n        total = total.checked_mul(d)?;\n        if total > MAX_BROADCAST_ELEMENTS {\n            return None;\n        }\n    }\n    Some(total)\n}\n\nfn broadcast_binary(\n    a: &[f32],\n    a_dims: &[i64],\n    b: &[f32],\n    b_dims: &[i64],\n    f: fn(f32, f32) -> f32,\n) -> Option<(Vec<f32>, Vec<i64>)> {\n    let out_dims = broadcast_shape(a_dims, b_dims)?;\n    let total = broadcast_total(&out_dims)?;\n    let mut result = Vec::with_capacity(total);\n    for i in 0..total {\n        let ai 
= broadcast_index(i, &out_dims, a_dims);\n        let bi = broadcast_index(i, &out_dims, b_dims);\n        result.push(f(a[ai], b[bi]));\n    }\n    Some((result, out_dims))\n}\n\nfn broadcast_binary_i64(\n    a: &[i64],\n    a_dims: &[i64],\n    b: &[i64],\n    b_dims: &[i64],\n    f: impl Fn(i64, i64) -> i64,\n) -> Option<(Vec<i64>, Vec<i64>)> {\n    let out_dims = broadcast_shape(a_dims, b_dims)?;\n    let total = broadcast_total(&out_dims)?;\n    let mut result = Vec::with_capacity(total);\n    for i in 0..total {\n        let ai = broadcast_index(i, &out_dims, a_dims);\n        let bi = broadcast_index(i, &out_dims, b_dims);\n        result.push(f(a[ai], b[bi]));\n    }\n    Some((result, out_dims))\n}\n\nfn eval_reshape(\n    node: &NodeProto,\n    inputs: &[&TensorProto],\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    if inputs.len() < 2 {\n        return None;\n    }\n    let vals = tensor_to_f32(inputs[0]);\n    let shape = tensor_to_i64(inputs[1]);\n    if vals.is_empty() || shape.is_empty() {\n        return None;\n    }\n    let allowzero = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"allowzero\")\n        .map(|a| a.i != 0)\n        .unwrap_or(false);\n    let mut new_dims: Vec<i64> = shape\n        .iter()\n        .enumerate()\n        .map(|(i, &d)| {\n            if d == 0 {\n                if allowzero {\n                    0\n                } else {\n                    *inputs[0].dims.get(i).unwrap_or(&1)\n                }\n            } else {\n                d\n            }\n        })\n        .collect();\n    if let Some(neg_idx) = new_dims.iter().position(|&d| d == -1) {\n        let known: i64 = new_dims\n            .iter()\n            .enumerate()\n            .filter(|&(i, &d)| i != neg_idx && d > 0)\n            .map(|(_, &d)| d)\n            .product();\n        let total: i64 = vals.len() as i64;\n        if known > 0 {\n            new_dims[neg_idx] = total / known;\n        
}\n    }\n    let t = make_f32_tensor(out_name, &new_dims, &vals, inputs[0].data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_squeeze(\n    node: &NodeProto,\n    inputs: &[&TensorProto],\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    let input = inputs[0];\n    let ndim = input.dims.len() as i64;\n    let raw_axes: Vec<i64> = if inputs.len() >= 2 {\n        tensor_to_i64(inputs[1])\n    } else {\n        node.attribute\n            .iter()\n            .find(|a| a.name == \"axes\")\n            .map(|a| a.ints.clone())\n            .unwrap_or_default()\n    };\n    let axes: Vec<usize> = raw_axes\n        .iter()\n        .map(|&a| {\n            if a < 0 {\n                (ndim + a) as usize\n            } else {\n                a as usize\n            }\n        })\n        .collect();\n    if axes.is_empty() {\n        let new_dims: Vec<i64> = input.dims.iter().copied().filter(|&d| d != 1).collect();\n        let vals = tensor_to_f32(input);\n        if vals.is_empty() {\n            return None;\n        }\n        let t = make_f32_tensor(out_name, &new_dims, &vals, input.data_type);\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    for &ax in &axes {\n        if ax >= input.dims.len() || input.dims[ax] != 1 {\n            return None;\n        }\n    }\n    let new_dims: Vec<i64> = input\n        .dims\n        .iter()\n        .enumerate()\n        .filter(|(i, _)| !axes.contains(i))\n        .map(|(_, &d)| d)\n        .collect();\n    let vals = tensor_to_f32(input);\n    if vals.is_empty() {\n        return None;\n    }\n    let t = make_f32_tensor(out_name, &new_dims, &vals, input.data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_unsqueeze(\n    node: &NodeProto,\n    inputs: &[&TensorProto],\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    let axes: Vec<i64> = if inputs.len() >= 2 {\n        tensor_to_i64(inputs[1])\n    } else {\n        node.attribute\n   
         .iter()\n            .find(|a| a.name == \"axes\")\n            .map(|a| a.ints.clone())\n            .unwrap_or_default()\n    };\n    let ndim = inputs[0].dims.len() + axes.len();\n    let mut new_dims = inputs[0].dims.clone();\n    let mut sorted_axes: Vec<usize> = axes\n        .iter()\n        .map(|&a| {\n            if a < 0 {\n                (ndim as i64 + a) as usize\n            } else {\n                a as usize\n            }\n        })\n        .collect();\n    sorted_axes.sort();\n    for &ax in &sorted_axes {\n        if ax <= new_dims.len() {\n            new_dims.insert(ax, 1);\n        }\n    }\n    let vals = tensor_to_f32(inputs[0]);\n    if vals.is_empty() {\n        return None;\n    }\n    let t = make_f32_tensor(out_name, &new_dims, &vals, inputs[0].data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_shape(\n    node: &NodeProto,\n    input: &TensorProto,\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    let dims = &input.dims;\n    if dims.is_empty() {\n        return None;\n    }\n    let ndim = dims.len() as i64;\n    let start_attr = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"start\")\n        .map(|a| a.i)\n        .unwrap_or(0);\n    let end_attr = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"end\")\n        .map(|a| a.i)\n        .unwrap_or(ndim);\n    let start = if start_attr < 0 {\n        (ndim + start_attr).max(0) as usize\n    } else {\n        (start_attr as usize).min(dims.len())\n    };\n    let end = if end_attr < 0 {\n        (ndim + end_attr).max(0) as usize\n    } else {\n        (end_attr as usize).min(dims.len())\n    };\n    let sliced: Vec<i64> = if start < end {\n        dims[start..end].to_vec()\n    } else {\n        vec![]\n    };\n    let t = TensorProto {\n        name: out_name.to_string(),\n        data_type: TensorProto::INT64,\n        dims: vec![sliced.len() as i64],\n        int64_data: sliced,\n      
  ..Default::default()\n    };\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_gather(\n    node: &NodeProto,\n    inputs: &[&TensorProto],\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    let axis = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"axis\")\n        .map(|a| a.i)\n        .unwrap_or(0);\n    let data = inputs[0];\n    let indices = tensor_to_i64(inputs[1]);\n    if indices.is_empty() || data.dims.is_empty() {\n        return None;\n    }\n    if data.dims.len() == 1 && axis == 0 {\n        let data_vals = tensor_to_f32(data);\n        if data_vals.is_empty() {\n            let data_i64 = tensor_to_i64(data);\n            if data_i64.is_empty() {\n                return None;\n            }\n            let result: Vec<i64> = indices\n                .iter()\n                .map(|&i| {\n                    let idx = if i < 0 {\n                        (data.dims[0] + i) as usize\n                    } else {\n                        i as usize\n                    };\n                    data_i64.get(idx).copied().unwrap_or(0)\n                })\n                .collect();\n            let out_dims = if inputs[1].dims.is_empty() {\n                vec![]\n            } else {\n                inputs[1].dims.clone()\n            };\n            let t = TensorProto {\n                name: out_name.to_string(),\n                data_type: TensorProto::INT64,\n                dims: out_dims,\n                int64_data: result,\n                ..Default::default()\n            };\n            return Some(vec![(out_name.to_string(), t)]);\n        }\n        let result: Vec<f32> = indices\n            .iter()\n            .map(|&i| {\n                let idx = if i < 0 {\n                    (data.dims[0] + i) as usize\n                } else {\n                    i as usize\n                };\n                data_vals.get(idx).copied().unwrap_or(0.0)\n            })\n            .collect();\n   
     let out_dims = if inputs[1].dims.is_empty() {\n            vec![]\n        } else {\n            inputs[1].dims.clone()\n        };\n        let t = make_f32_tensor(out_name, &out_dims, &result, data.data_type);\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    None\n}\n\nfn eval_slice(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {\n    let data = inputs[0];\n    let starts = tensor_to_i64(inputs[1]);\n    let ends = tensor_to_i64(inputs[2]);\n    if starts.is_empty() || ends.is_empty() {\n        return None;\n    }\n    let axes: Vec<i64> = if inputs.len() > 3 {\n        tensor_to_i64(inputs[3])\n    } else {\n        (0..starts.len() as i64).collect()\n    };\n    let steps: Vec<i64> = if inputs.len() > 4 {\n        tensor_to_i64(inputs[4])\n    } else {\n        vec![1; starts.len()]\n    };\n    if starts.len() != ends.len() || axes.len() != starts.len() || steps.len() != starts.len() {\n        return None;\n    }\n    let rank = data.dims.len();\n    if rank == 0 {\n        return None;\n    }\n\n    let mut per_axis_range: Vec<(i64, i64, i64)> = (0..rank as i64)\n        .map(|d| (0, data.dims[d as usize], 1))\n        .collect();\n    for (i, &raw_axis) in axes.iter().enumerate() {\n        let a = if raw_axis < 0 {\n            rank as i64 + raw_axis\n        } else {\n            raw_axis\n        };\n        if a < 0 || a >= rank as i64 {\n            return None;\n        }\n        let dim = data.dims[a as usize];\n        let step = steps[i];\n        if step == 0 {\n            return None;\n        }\n        if dim == 0 {\n            // Zero-length axis: any slice yields an empty output on that\n            // axis.  
Record (0, 0, step) and skip clamping to avoid the\n            // clamp(..., 0, dim - 1) == clamp(..., 0, -1) inverted range.\n            per_axis_range[a as usize] = (0, 0, step);\n            continue;\n        }\n        let raw_start = starts[i];\n        let raw_end = ends[i];\n        let clamp = |v: i64, lo: i64, hi: i64| -> i64 { v.clamp(lo, hi) };\n        let (s, e) = if step > 0 {\n            // ONNX forward slice: start in [0, dim], end in [0, dim],\n            // both treated as exclusive upper bound.\n            let s = clamp(\n                if raw_start < 0 {\n                    dim + raw_start\n                } else {\n                    raw_start\n                },\n                0,\n                dim,\n            );\n            let e = clamp(if raw_end < 0 { dim + raw_end } else { raw_end }, 0, dim);\n            (s, e)\n        } else {\n            // ONNX reverse slice: start in [0, dim-1] (inclusive first\n            // read), end in [-1, dim-1] (exclusive lower bound; -1\n            // means \"walk past index 0\", i.e. 
include element 0).\n            let s = clamp(\n                if raw_start < 0 {\n                    dim + raw_start\n                } else {\n                    raw_start\n                },\n                0,\n                dim - 1,\n            );\n            let resolved_end = if raw_end == i64::MIN {\n                -1\n            } else if raw_end < 0 {\n                dim + raw_end\n            } else {\n                raw_end\n            };\n            let e = clamp(resolved_end, -1, dim - 1);\n            (s, e)\n        };\n        per_axis_range[a as usize] = (s, e, step);\n    }\n\n    let out_dims: Vec<i64> = per_axis_range\n        .iter()\n        .map(|(s, e, st)| {\n            if *st > 0 {\n                ((e - s + st - 1) / st).max(0)\n            } else {\n                ((s - e + (-st) - 1) / (-st)).max(0)\n            }\n        })\n        .collect();\n    let total = broadcast_total(&out_dims)?;\n    if total == 0 {\n        let t = TensorProto {\n            name: out_name.to_string(),\n            data_type: data.data_type,\n            dims: out_dims,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n\n    let in_strides: Vec<i64> = {\n        let mut s = vec![1i64; rank];\n        for i in (0..rank.saturating_sub(1)).rev() {\n            s[i] = s[i + 1] * data.dims[i + 1];\n        }\n        s\n    };\n    let out_strides: Vec<i64> = {\n        let mut s = vec![1i64; rank];\n        for i in (0..rank.saturating_sub(1)).rev() {\n            s[i] = s[i + 1] * out_dims[i + 1];\n        }\n        s\n    };\n\n    let src_index = |o: i64| -> i64 {\n        let mut rem = o;\n        let mut src = 0i64;\n        for d in 0..rank {\n            let coord = rem / out_strides[d];\n            rem %= out_strides[d];\n            let (s_axis, _, st) = per_axis_range[d];\n            src += (s_axis + coord * st) * in_strides[d];\n        }\n        src\n    };\n\n    if 
data.data_type == TensorProto::INT64 {\n        let vals = tensor_to_i64(data);\n        if vals.is_empty() {\n            return None;\n        }\n        let mut result = Vec::with_capacity(total);\n        for o in 0..total {\n            result.push(*vals.get(src_index(o as i64) as usize)?);\n        }\n        let t = TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: out_dims,\n            int64_data: result,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n    let vals = tensor_to_f32(data);\n    if vals.is_empty() {\n        return None;\n    }\n    let mut result = Vec::with_capacity(total);\n    for o in 0..total {\n        result.push(*vals.get(src_index(o as i64) as usize)?);\n    }\n    let t = make_f32_tensor(out_name, &out_dims, &result, data.data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_scatter_nd(inputs: &[&TensorProto], out_name: &str) -> Option<Vec<(String, TensorProto)>> {\n    let data = inputs[0];\n    let indices = inputs[1];\n    let updates = inputs[2];\n    let rank = data.dims.len();\n    if rank == 0 || indices.dims.is_empty() {\n        return None;\n    }\n    let q = *indices.dims.last()? 
as usize;\n    if q == 0 || q > rank {\n        return None;\n    }\n    let total = broadcast_total(&data.dims)?;\n    let in_strides: Vec<i64> = {\n        let mut s = vec![1i64; rank];\n        for i in (0..rank.saturating_sub(1)).rev() {\n            s[i] = s[i + 1] * data.dims[i + 1];\n        }\n        s\n    };\n    let trail_size: usize = data.dims[q..].iter().map(|&d| d as usize).product();\n    let scatter_count: usize = indices.dims[..indices.dims.len() - 1]\n        .iter()\n        .map(|&d| d as usize)\n        .product();\n    let idx_vals = tensor_to_i64(indices);\n    if idx_vals.len() != scatter_count * q {\n        return None;\n    }\n\n    if data.data_type == TensorProto::INT64 {\n        let mut buf = tensor_to_i64(data);\n        if buf.len() != total {\n            return None;\n        }\n        let upd_vals = tensor_to_i64(updates);\n        if upd_vals.len() != scatter_count * trail_size {\n            return None;\n        }\n        for s in 0..scatter_count {\n            let mut base = 0i64;\n            for d in 0..q {\n                let mut idx = idx_vals[s * q + d];\n                if idx < 0 {\n                    idx += data.dims[d];\n                }\n                if idx < 0 || idx >= data.dims[d] {\n                    return None;\n                }\n                base += idx * in_strides[d];\n            }\n            for k in 0..trail_size {\n                buf[base as usize + k] = upd_vals[s * trail_size + k];\n            }\n        }\n        let t = TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: data.dims.clone(),\n            int64_data: buf,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n\n    let mut buf = tensor_to_f32(data);\n    if buf.len() != total {\n        return None;\n    }\n    let upd_vals = tensor_to_f32(updates);\n    if upd_vals.len() != scatter_count * 
trail_size {\n        return None;\n    }\n    for s in 0..scatter_count {\n        let mut base = 0i64;\n        for d in 0..q {\n            let mut idx = idx_vals[s * q + d];\n            if idx < 0 {\n                idx += data.dims[d];\n            }\n            if idx < 0 || idx >= data.dims[d] {\n                return None;\n            }\n            base += idx * in_strides[d];\n        }\n        for k in 0..trail_size {\n            buf[base as usize + k] = upd_vals[s * trail_size + k];\n        }\n    }\n    let t = make_f32_tensor(out_name, &data.dims, &buf, data.data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn eval_split(\n    node: &NodeProto,\n    inputs: &[&TensorProto],\n    output_names: &[String],\n) -> Option<Vec<(String, TensorProto)>> {\n    let data = inputs.first()?;\n    let rank = data.dims.len();\n    if rank == 0 {\n        return None;\n    }\n    let raw_axis = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"axis\")\n        .map(|a| a.i)\n        .unwrap_or(0);\n    let axis = if raw_axis < 0 {\n        rank as i64 + raw_axis\n    } else {\n        raw_axis\n    } as usize;\n    if axis >= rank {\n        return None;\n    }\n    let split_sizes: Vec<i64> = if inputs.len() >= 2 {\n        tensor_to_i64(inputs[1])\n    } else if let Some(attr) = node.attribute.iter().find(|a| a.name == \"split\") {\n        attr.ints.clone()\n    } else {\n        let n = output_names.iter().filter(|s| !s.is_empty()).count() as i64;\n        if n == 0 {\n            return None;\n        }\n        let dim = data.dims[axis];\n        if dim % n != 0 {\n            return None;\n        }\n        vec![dim / n; n as usize]\n    };\n    if split_sizes.iter().sum::<i64>() != data.dims[axis] {\n        return None;\n    }\n    let outputs: Vec<&str> = output_names\n        .iter()\n        .filter(|s| !s.is_empty())\n        .map(String::as_str)\n        .collect();\n    if outputs.len() != split_sizes.len() {\n  
      return None;\n    }\n\n    let prefix: usize = data.dims[..axis].iter().map(|&d| d as usize).product();\n    let suffix: usize = data.dims[axis + 1..].iter().map(|&d| d as usize).product();\n    let axis_in: usize = data.dims[axis] as usize;\n\n    let mut result = Vec::with_capacity(outputs.len());\n    let is_int64 = data.data_type == TensorProto::INT64;\n    let mut offset = 0usize;\n    for (i, &sz) in split_sizes.iter().enumerate() {\n        let sz_us = usize::try_from(sz).ok()?;\n        if sz_us == 0 {\n            return None;\n        }\n        let mut out_dims = data.dims.clone();\n        out_dims[axis] = sz;\n        let total = prefix * sz_us * suffix;\n        if is_int64 {\n            let vals = tensor_to_i64(data);\n            if vals.is_empty() {\n                return None;\n            }\n            let mut chunk = Vec::with_capacity(total);\n            for p in 0..prefix {\n                for ai in 0..sz_us {\n                    let src_axis = offset + ai;\n                    let src_base = (p * axis_in + src_axis) * suffix;\n                    chunk.extend_from_slice(&vals[src_base..src_base + suffix]);\n                }\n            }\n            let t = TensorProto {\n                name: outputs[i].to_string(),\n                data_type: TensorProto::INT64,\n                dims: out_dims,\n                int64_data: chunk,\n                ..Default::default()\n            };\n            result.push((outputs[i].to_string(), t));\n        } else {\n            let vals = tensor_to_f32(data);\n            if vals.is_empty() {\n                return None;\n            }\n            let mut chunk = Vec::with_capacity(total);\n            for p in 0..prefix {\n                for ai in 0..sz_us {\n                    let src_axis = offset + ai;\n                    let src_base = (p * axis_in + src_axis) * suffix;\n                    chunk.extend_from_slice(&vals[src_base..src_base + suffix]);\n                }\n       
     }\n            let t = make_f32_tensor(outputs[i], &out_dims, &chunk, data.data_type);\n            result.push((outputs[i].to_string(), t));\n        }\n        offset += sz_us;\n    }\n    Some(result)\n}\n\nfn eval_concat(\n    node: &NodeProto,\n    inputs: &[&TensorProto],\n    out_name: &str,\n) -> Option<Vec<(String, TensorProto)>> {\n    if inputs.is_empty() {\n        return None;\n    }\n    let raw_axis = node\n        .attribute\n        .iter()\n        .find(|a| a.name == \"axis\")\n        .map(|a| a.i)\n        .unwrap_or(0);\n    let rank = inputs[0].dims.len();\n    if !inputs.iter().all(|t| t.dims.len() == rank) {\n        return None;\n    }\n    if rank == 0 {\n        return None;\n    }\n    let axis = if raw_axis < 0 {\n        (rank as i64 + raw_axis) as usize\n    } else {\n        raw_axis as usize\n    };\n    if axis >= rank {\n        return None;\n    }\n    for d in 0..rank {\n        if d == axis {\n            continue;\n        }\n        let expected = inputs[0].dims[d];\n        if !inputs.iter().all(|t| t.dims[d] == expected) {\n            return None;\n        }\n    }\n    let mut out_dims = inputs[0].dims.clone();\n    out_dims[axis] = inputs.iter().map(|t| t.dims[axis]).sum();\n\n    let prefix_size: usize = out_dims[..axis].iter().map(|&d| d as usize).product();\n    let out_axis: usize = out_dims[axis] as usize;\n    let suffix_size: usize = out_dims[axis + 1..].iter().map(|&d| d as usize).product();\n    let out_total = prefix_size\n        .checked_mul(out_axis)?\n        .checked_mul(suffix_size)?;\n    if out_total > MAX_BROADCAST_ELEMENTS {\n        return None;\n    }\n\n    // ONNX Concat requires homogeneous input element types, so the first\n    // input's declared type is authoritative.\n    let is_int64 = inputs[0].data_type == TensorProto::INT64;\n\n    if is_int64 {\n        let mut result: Vec<i64> = vec![0; out_total];\n        let mut axis_offset: usize = 0;\n        for t in inputs {\n            
let t_vals = tensor_to_i64(t);\n            let t_axis = t.dims[axis] as usize;\n            if t_axis > 0 && t_vals.is_empty() {\n                return None;\n            }\n            for p in 0..prefix_size {\n                for ai in 0..t_axis {\n                    for s in 0..suffix_size {\n                        let src = (p * t_axis + ai) * suffix_size + s;\n                        let dst = (p * out_axis + axis_offset + ai) * suffix_size + s;\n                        result[dst] = t_vals[src];\n                    }\n                }\n            }\n            axis_offset += t_axis;\n        }\n        let t = TensorProto {\n            name: out_name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: out_dims,\n            int64_data: result,\n            ..Default::default()\n        };\n        return Some(vec![(out_name.to_string(), t)]);\n    }\n\n    let mut result: Vec<f32> = vec![0.0; out_total];\n    let mut axis_offset: usize = 0;\n    for t in inputs {\n        let t_vals = tensor_to_f32(t);\n        let t_axis = t.dims[axis] as usize;\n        if t_axis > 0 && t_vals.is_empty() {\n            return None;\n        }\n        for p in 0..prefix_size {\n            for ai in 0..t_axis {\n                for s in 0..suffix_size {\n                    let src = (p * t_axis + ai) * suffix_size + s;\n                    let dst = (p * out_axis + axis_offset + ai) * suffix_size + s;\n                    result[dst] = t_vals[src];\n                }\n            }\n        }\n        axis_offset += t_axis;\n    }\n    let t = make_f32_tensor(out_name, &out_dims, &result, inputs[0].data_type);\n    Some(vec![(out_name.to_string(), t)])\n}\n\nfn make_f32_tensor(name: &str, dims: &[i64], vals: &[f32], target_type: i32) -> TensorProto {\n    match target_type {\n        TensorProto::INT64 => TensorProto {\n            name: name.to_string(),\n            data_type: TensorProto::INT64,\n            dims: dims.to_vec(),\n       
     int64_data: vals.iter().map(|&v| v as i64).collect(),\n            ..Default::default()\n        },\n        TensorProto::INT32 => TensorProto {\n            name: name.to_string(),\n            data_type: TensorProto::INT32,\n            dims: dims.to_vec(),\n            int32_data: vals.iter().map(|&v| v as i32).collect(),\n            ..Default::default()\n        },\n        TensorProto::DOUBLE => TensorProto {\n            name: name.to_string(),\n            data_type: TensorProto::DOUBLE,\n            dims: dims.to_vec(),\n            double_data: vals.iter().map(|&v| v as f64).collect(),\n            ..Default::default()\n        },\n        TensorProto::BOOL => TensorProto {\n            name: name.to_string(),\n            data_type: TensorProto::BOOL,\n            dims: dims.to_vec(),\n            int32_data: vals.iter().map(|&v| (v != 0.0) as i32).collect(),\n            ..Default::default()\n        },\n        _ => TensorProto {\n            name: name.to_string(),\n            data_type: TensorProto::FLOAT,\n            dims: dims.to_vec(),\n            float_data: vals.to_vec(),\n            ..Default::default()\n        },\n    }\n}\n\nstruct ConvBnFusion {\n    conv_idx: usize,\n    bn_idx: usize,\n    bn_output: String,\n    w_name: String,\n    bias_name: String,\n    has_bias: bool,\n    orig_bias: Vec<f32>,\n    gamma: Vec<f32>,\n    beta: Vec<f32>,\n    mean: Vec<f32>,\n    var: Vec<f32>,\n    eps: f32,\n    // Initialiser names that become dead after the fusion: the BN's\n    // four parameter inputs (gamma / beta / running mean / running\n    // variance) and, if the Conv had no bias before fusion, the\n    // auto-named \"<w>_fused_bias\" we create.  
Collected here so the\n    // caller can purge them in a single post-pass sweep without\n    // re-walking every BN node.\n    stale_bn_param_names: Vec<String>,\n}\n\npub fn fuse_conv_batchnorm(graph: &mut GraphProto) -> usize {\n    let fusions = {\n        let init_map: HashMap<&str, &TensorProto> = graph\n            .initializer\n            .iter()\n            .map(|t| (t.name.as_str(), t))\n            .collect();\n\n        let node_output_map: HashMap<&str, usize> = graph\n            .node\n            .iter()\n            .enumerate()\n            .flat_map(|(i, n)| n.output.iter().map(move |o| (o.as_str(), i)))\n            .collect();\n\n        let mut fusions: Vec<ConvBnFusion> = Vec::new();\n\n        for (bn_idx, bn_node) in graph.node.iter().enumerate() {\n            if bn_node.op_type != \"BatchNormalization\" || bn_node.input.len() < 5 {\n                continue;\n            }\n            let bn_input = &bn_node.input[0];\n            let conv_idx = match node_output_map.get(bn_input.as_str()) {\n                Some(&idx) => idx,\n                None => continue,\n            };\n            let conv_node = &graph.node[conv_idx];\n            if conv_node.op_type != \"Conv\" || conv_node.output.is_empty() {\n                continue;\n            }\n            let consumers: usize = graph\n                .node\n                .iter()\n                .filter(|n| n.input.contains(&conv_node.output[0]))\n                .count();\n            if consumers != 1 {\n                continue;\n            }\n\n            let gamma = match init_map.get(bn_node.input[1].as_str()) {\n                Some(t) => tensor_to_f32(t),\n                None => continue,\n            };\n            let beta = match init_map.get(bn_node.input[2].as_str()) {\n                Some(t) => tensor_to_f32(t),\n                None => continue,\n            };\n            let mean = match init_map.get(bn_node.input[3].as_str()) {\n                Some(t) => 
tensor_to_f32(t),\n                None => continue,\n            };\n            let var = match init_map.get(bn_node.input[4].as_str()) {\n                Some(t) => tensor_to_f32(t),\n                None => continue,\n            };\n\n            if gamma.is_empty()\n                || gamma.len() != beta.len()\n                || gamma.len() != mean.len()\n                || gamma.len() != var.len()\n            {\n                continue;\n            }\n\n            let bn_output = match bn_node.output.first() {\n                Some(o) if !o.is_empty() => o.clone(),\n                _ => continue,\n            };\n\n            let eps = bn_node\n                .attribute\n                .iter()\n                .find(|a| a.name == \"epsilon\")\n                .map(|a| a.f)\n                .unwrap_or(1e-5);\n\n            let w_name = conv_node.input[1].clone();\n            let has_bias = conv_node.input.len() > 2;\n            let bias_name = if has_bias {\n                conv_node.input[2].clone()\n            } else {\n                format!(\"{}_fused_bias\", w_name)\n            };\n            let orig_bias = if has_bias {\n                init_map\n                    .get(conv_node.input[2].as_str())\n                    .map(|t| tensor_to_f32(t))\n                    .unwrap_or_default()\n            } else {\n                vec![]\n            };\n\n            let stale_bn_param_names = vec![\n                bn_node.input[1].clone(),\n                bn_node.input[2].clone(),\n                bn_node.input[3].clone(),\n                bn_node.input[4].clone(),\n            ];\n\n            fusions.push(ConvBnFusion {\n                conv_idx,\n                bn_idx,\n                bn_output,\n                w_name,\n                bias_name,\n                has_bias,\n                orig_bias,\n                gamma,\n                beta,\n                mean,\n                var,\n                eps,\n                
stale_bn_param_names,\n            });\n        }\n\n        fusions\n    };\n\n    if fusions.is_empty() {\n        return 0;\n    }\n\n    let mut removed_bn: HashSet<usize> = HashSet::new();\n    let mut stale_init_names: HashSet<String> = HashSet::new();\n\n    for f in &fusions {\n        let channels = f.gamma.len();\n        let scale: Vec<f32> = (0..channels)\n            .map(|c| f.gamma[c] / (f.var[c] + f.eps).sqrt())\n            .collect();\n\n        let w_ok = if let Some(w_init) = graph.initializer.iter_mut().find(|i| i.name == f.w_name) {\n            let mut w_data = tensor_to_f32(w_init);\n            // tensor_to_f32 returns empty for unsupported dtypes\n            // (e.g. f16 / bf16 weights we don't yet convert); skip\n            // the fusion for this Conv rather than silently clearing\n            // the initializer into a zero-length FLOAT tensor that\n            // would fail every downstream shape check.\n            if w_data.is_empty() {\n                false\n            } else if !w_init.dims.is_empty() && w_init.dims[0] as usize == channels {\n                let per_filter = w_data.len() / channels;\n                for c in 0..channels {\n                    for j in 0..per_filter {\n                        w_data[c * per_filter + j] *= scale[c];\n                    }\n                }\n                w_init.float_data = w_data;\n                w_init.raw_data.clear();\n                // The initialiser may have arrived as half / bfloat\n                // encoded in raw_data; float_data is FLOAT by\n                // definition, so stamp the tensor metadata to match\n                // the new representation.\n                w_init.data_type = TensorProto::FLOAT;\n                true\n            } else {\n                false\n            }\n        } else {\n            false\n        };\n        if !w_ok {\n            continue;\n        }\n\n        let fused_bias: Vec<f32> = (0..channels)\n            .map(|c| {\n 
               let ob = f.orig_bias.get(c).copied().unwrap_or(0.0);\n                (ob - f.mean[c]) * scale[c] + f.beta[c]\n            })\n            .collect();\n\n        if let Some(b_init) = graph.initializer.iter_mut().find(|i| i.name == f.bias_name) {\n            b_init.float_data = fused_bias;\n            b_init.raw_data.clear();\n            b_init.dims = vec![channels as i64];\n            b_init.data_type = TensorProto::FLOAT;\n        } else {\n            graph.initializer.push(TensorProto {\n                name: f.bias_name.clone(),\n                data_type: TensorProto::FLOAT,\n                dims: vec![channels as i64],\n                float_data: fused_bias,\n                ..Default::default()\n            });\n        }\n\n        let conv_node = &mut graph.node[f.conv_idx];\n        if !f.has_bias {\n            conv_node.input.push(f.bias_name.clone());\n        }\n        conv_node.output[0] = f.bn_output.clone();\n\n        removed_bn.insert(f.bn_idx);\n        stale_init_names.extend(f.stale_bn_param_names.iter().cloned());\n    }\n\n    if !removed_bn.is_empty() {\n        let mut idx = 0;\n        graph.node.retain(|_| {\n            let keep = !removed_bn.contains(&idx);\n            idx += 1;\n            keep\n        });\n    }\n\n    if !stale_init_names.is_empty() {\n        // Only drop BN parameter initialisers that no surviving node\n        // still references.  Rare in practice but cheap to verify\n        // and prevents accidentally deleting an initialiser shared\n        // between a fused Conv+BN and an unrelated node elsewhere.\n        let still_used: HashSet<&str> = graph\n            .node\n            .iter()\n            .flat_map(|n| n.input.iter().map(String::as_str))\n            .collect();\n        graph.initializer.retain(|init| {\n            !stale_init_names.contains(&init.name) || still_used.contains(init.name.as_str())\n        });\n    }\n\n    removed_bn.len()\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/onnx_proto.rs",
    "content": "#[allow(clippy::doc_overindented_list_items)]\npub mod onnx {\n    include!(concat!(env!(\"OUT_DIR\"), \"/onnx.rs\"));\n}\n\nuse std::collections::{HashMap, HashSet};\nuse std::path::Path;\n\nuse prost::Message;\n\nuse crate::error::{DsperseError, Result};\n\npub use onnx::{\n    AttributeProto, GraphProto, ModelProto, NodeProto, OperatorSetIdProto, TensorProto, TypeProto,\n    ValueInfoProto,\n};\n\npub use super::onnx_shapes::{\n    elem_type_from_value_info, resolve_dynamic_input_shapes, set_vi_shape, shape_from_value_info,\n    strip_symbolic_value_info, vi_shape,\n};\n\npub fn load_model(path: &Path) -> Result<ModelProto> {\n    let bytes = crate::utils::limits::read_checked(path)?;\n    ModelProto::decode(bytes.as_slice())\n        .map_err(|e| DsperseError::Slicer(format!(\"decode {}: {e}\", path.display())))\n}\n\nfn canonicalize_node_attributes(nodes: &mut [NodeProto]) {\n    for node in nodes {\n        node.attribute.sort_by(|a, b| a.name.cmp(&b.name));\n        for attr in &mut node.attribute {\n            if let Some(g) = attr.g.as_mut() {\n                canonicalize_node_attributes(&mut g.node);\n            }\n            for g in &mut attr.graphs {\n                canonicalize_node_attributes(&mut g.node);\n            }\n        }\n    }\n}\n\npub fn save_model(model: &ModelProto, path: &Path) -> Result<()> {\n    let mut model = model.clone();\n    if let Some(graph) = model.graph.as_mut() {\n        canonicalize_node_attributes(&mut graph.node);\n    }\n    for func in &mut model.functions {\n        canonicalize_node_attributes(&mut func.node);\n    }\n    if let Some(parent) = path.parent() {\n        std::fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?;\n    }\n    let bytes = model.encode_to_vec();\n    std::fs::write(path, bytes).map_err(|e| DsperseError::io(e, path))\n}\n\npub fn make_tensor_value_info(name: &str, elem_type: i32, shape: &[i64]) -> ValueInfoProto {\n    ValueInfoProto {\n        name: 
name.to_string(),\n        r#type: Some(TypeProto {\n            denotation: String::new(),\n            value: Some(onnx::type_proto::Value::TensorType(\n                onnx::type_proto::Tensor {\n                    elem_type,\n                    shape: Some(onnx::TensorShapeProto {\n                        dim: shape\n                            .iter()\n                            .map(|&d| onnx::tensor_shape_proto::Dimension {\n                                denotation: String::new(),\n                                value: Some(onnx::tensor_shape_proto::dimension::Value::DimValue(\n                                    d,\n                                )),\n                            })\n                            .collect(),\n                    }),\n                },\n            )),\n        }),\n        doc_string: String::new(),\n        metadata_props: vec![],\n    }\n}\n\npub fn make_tensor(name: &str, elem_type: i32, dims: &[i64], float_data: Vec<f32>) -> TensorProto {\n    TensorProto {\n        name: name.to_string(),\n        data_type: elem_type,\n        dims: dims.to_vec(),\n        float_data,\n        ..Default::default()\n    }\n}\n\npub fn make_node(\n    op_type: &str,\n    inputs: Vec<String>,\n    outputs: Vec<String>,\n    attributes: Vec<AttributeProto>,\n) -> NodeProto {\n    NodeProto {\n        op_type: op_type.to_string(),\n        input: inputs,\n        output: outputs,\n        attribute: attributes,\n        name: String::new(),\n        domain: String::new(),\n        doc_string: String::new(),\n        overload: String::new(),\n        metadata_props: vec![],\n        device_configurations: vec![],\n    }\n}\n\npub fn make_graph(\n    name: &str,\n    nodes: Vec<NodeProto>,\n    inputs: Vec<ValueInfoProto>,\n    outputs: Vec<ValueInfoProto>,\n    initializers: Vec<TensorProto>,\n) -> GraphProto {\n    GraphProto {\n        name: name.to_string(),\n        node: nodes,\n        input: inputs,\n        output: outputs,\n   
     initializer: initializers,\n        ..Default::default()\n    }\n}\n\npub fn make_model(graph: GraphProto, opset_version: i64) -> ModelProto {\n    ModelProto {\n        ir_version: 8,\n        graph: Some(graph),\n        opset_import: vec![OperatorSetIdProto {\n            domain: String::new(),\n            version: opset_version,\n        }],\n        ..Default::default()\n    }\n}\n\npub fn make_attribute_ints(name: &str, ints: &[i64]) -> AttributeProto {\n    AttributeProto {\n        name: name.to_string(),\n        r#type: onnx::attribute_proto::AttributeType::Ints as i32,\n        ints: ints.to_vec(),\n        ..Default::default()\n    }\n}\n\npub fn make_attribute_int(name: &str, val: i64) -> AttributeProto {\n    AttributeProto {\n        name: name.to_string(),\n        r#type: onnx::attribute_proto::AttributeType::Int as i32,\n        i: val,\n        ..Default::default()\n    }\n}\n\npub fn get_attribute_ints(node: &NodeProto, name: &str) -> Option<Vec<i64>> {\n    node.attribute\n        .iter()\n        .find(|a| a.name == name)\n        .map(|a| a.ints.clone())\n}\n\npub fn get_attribute_int(node: &NodeProto, name: &str) -> Option<i64> {\n    node.attribute.iter().find(|a| a.name == name).map(|a| a.i)\n}\n\npub fn get_attribute_float(node: &NodeProto, name: &str) -> Option<f32> {\n    node.attribute.iter().find(|a| a.name == name).map(|a| a.f)\n}\n\npub fn make_attribute_float(name: &str, val: f32) -> AttributeProto {\n    AttributeProto {\n        name: name.to_string(),\n        f: val,\n        r#type: 1,\n        ..Default::default()\n    }\n}\n\npub fn tensor_to_i64(tensor: &TensorProto) -> Vec<i64> {\n    if !tensor.int64_data.is_empty() {\n        return tensor.int64_data.clone();\n    }\n    if !tensor.raw_data.is_empty() && tensor.data_type == TensorProto::INT64 {\n        if !tensor.raw_data.len().is_multiple_of(8) {\n            tracing::warn!(\n                tensor = %tensor.name,\n                raw_len = 
tensor.raw_data.len(),\n                \"misaligned INT64 raw_data, skipping\"\n            );\n            return Vec::new();\n        }\n        return tensor\n            .raw_data\n            .chunks_exact(8)\n            .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))\n            .collect();\n    }\n    if !tensor.int32_data.is_empty() {\n        return tensor.int32_data.iter().map(|&v| v as i64).collect();\n    }\n    Vec::new()\n}\n\npub fn tensor_to_f32(tensor: &TensorProto) -> Vec<f32> {\n    if !tensor.float_data.is_empty() {\n        return tensor.float_data.clone();\n    }\n    if !tensor.raw_data.is_empty() && tensor.data_type == TensorProto::FLOAT {\n        let chunks = tensor.raw_data.chunks_exact(4);\n        if !chunks.remainder().is_empty() {\n            return Vec::new();\n        }\n        return chunks\n            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))\n            .collect();\n    }\n    if !tensor.int64_data.is_empty() {\n        return tensor.int64_data.iter().map(|&v| v as f32).collect();\n    }\n    if !tensor.int32_data.is_empty() {\n        return tensor.int32_data.iter().map(|&v| v as f32).collect();\n    }\n    if !tensor.double_data.is_empty() {\n        return tensor.double_data.iter().map(|&v| v as f32).collect();\n    }\n    if !tensor.raw_data.is_empty() {\n        match tensor.data_type {\n            TensorProto::INT64 => {\n                let chunks = tensor.raw_data.chunks_exact(8);\n                if !chunks.remainder().is_empty() {\n                    return Vec::new();\n                }\n                return chunks\n                    .map(|c| {\n                        i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32\n                    })\n                    .collect();\n            }\n            TensorProto::INT32 => {\n                let chunks = tensor.raw_data.chunks_exact(4);\n                if !chunks.remainder().is_empty() 
{\n                    return Vec::new();\n                }\n                return chunks\n                    .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)\n                    .collect();\n            }\n            TensorProto::DOUBLE => {\n                let chunks = tensor.raw_data.chunks_exact(8);\n                if !chunks.remainder().is_empty() {\n                    return Vec::new();\n                }\n                return chunks\n                    .map(|c| {\n                        f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32\n                    })\n                    .collect();\n            }\n            _ => {}\n        }\n    }\n    Vec::new()\n}\n\n/// Decode a TensorProto into `Vec<f64>` directly, without going\n/// through `f32`.  FLOAT / DOUBLE / INT32 payloads — in either\n/// the typed `*_data` fields or the little-endian `raw_data`\n/// byte stream — round-trip exactly: DOUBLE keeps its full 52-\n/// bit mantissa, FLOAT widens losslessly, and INT32 is always\n/// within f64's exact-integer range.\n///\n/// INT64 is a partial exception.  f64 exactly represents every\n/// integer in `[-2^53, 2^53]`; INT64 magnitudes beyond 2^53 are\n/// rounded to the nearest representable f64 and are not\n/// preserved bit-for-bit.  
This still beats the previous\n/// `tensor_to_f32 -> f64::from(f32)` chain (which truncated at\n/// 2^24) but callers that need full INT64 fidelity must not use\n/// this decoder.\n///\n/// Returns an empty `Vec` on unsupported / unrecognised dtypes\n/// or malformed `raw_data` length so callers can use the\n/// existing `data.is_empty()` skip path.\npub fn tensor_to_f64(tensor: &TensorProto) -> Vec<f64> {\n    if !tensor.double_data.is_empty() {\n        return tensor.double_data.clone();\n    }\n    if !tensor.float_data.is_empty() {\n        return tensor.float_data.iter().map(|&v| f64::from(v)).collect();\n    }\n    if !tensor.int64_data.is_empty() {\n        #[allow(clippy::cast_precision_loss)]\n        return tensor.int64_data.iter().map(|&v| v as f64).collect();\n    }\n    if !tensor.int32_data.is_empty() {\n        return tensor.int32_data.iter().map(|&v| f64::from(v)).collect();\n    }\n    if tensor.raw_data.is_empty() {\n        return Vec::new();\n    }\n    match tensor.data_type {\n        TensorProto::DOUBLE => {\n            let chunks = tensor.raw_data.chunks_exact(8);\n            if !chunks.remainder().is_empty() {\n                return Vec::new();\n            }\n            chunks\n                .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))\n                .collect()\n        }\n        TensorProto::FLOAT => {\n            let chunks = tensor.raw_data.chunks_exact(4);\n            if !chunks.remainder().is_empty() {\n                return Vec::new();\n            }\n            chunks\n                .map(|c| f64::from(f32::from_le_bytes([c[0], c[1], c[2], c[3]])))\n                .collect()\n        }\n        TensorProto::INT64 => {\n            let chunks = tensor.raw_data.chunks_exact(8);\n            if !chunks.remainder().is_empty() {\n                return Vec::new();\n            }\n            #[allow(clippy::cast_precision_loss)]\n            chunks\n                .map(|c| {\n             
       i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f64\n                })\n                .collect()\n        }\n        TensorProto::INT32 => {\n            let chunks = tensor.raw_data.chunks_exact(4);\n            if !chunks.remainder().is_empty() {\n                return Vec::new();\n            }\n            chunks\n                .map(|c| f64::from(i32::from_le_bytes([c[0], c[1], c[2], c[3]])))\n                .collect()\n        }\n        _ => Vec::new(),\n    }\n}\n\npub fn build_initializer_map(graph: &GraphProto) -> HashMap<String, &TensorProto> {\n    graph\n        .initializer\n        .iter()\n        .map(|i| (i.name.clone(), i))\n        .collect()\n}\n\npub fn build_value_info_map(graph: &GraphProto) -> HashMap<String, &ValueInfoProto> {\n    let mut map: HashMap<String, &ValueInfoProto> = HashMap::new();\n    for vi in &graph.input {\n        map.insert(vi.name.clone(), vi);\n    }\n    for vi in &graph.output {\n        map.insert(vi.name.clone(), vi);\n    }\n    for vi in &graph.value_info {\n        map.insert(vi.name.clone(), vi);\n    }\n    map\n}\n\nimpl TensorProto {\n    pub const FLOAT: i32 = 1;\n    pub const INT64: i32 = 7;\n    pub const DOUBLE: i32 = 11;\n    pub const INT32: i32 = 6;\n    pub const FLOAT16: i32 = 10;\n    pub const BOOL: i32 = 9;\n}\n\nfn is_paddable_shape(target: &[i64], donor: &[i64]) -> bool {\n    if target.len() != donor.len() || target.is_empty() {\n        return false;\n    }\n    let last = target.len() - 1;\n    target[..last] == donor[..last] && donor[last] < target[last] && donor[last] > 0\n}\n\npub fn validate_initializer_compatibility(\n    initializers: &[TensorProto],\n    donor_init_map: &HashMap<String, &TensorProto>,\n    context: &str,\n) -> Result<()> {\n    for init in initializers {\n        if let Some(donor) = donor_init_map.get(&init.name) {\n            if init.data_type != donor.data_type {\n                return Err(DsperseError::Pipeline(format!(\n   
                 \"dtype mismatch for initializer '{}' in {context}: slice has dtype {}, consumer has dtype {}\",\n                    init.name, init.data_type, donor.data_type\n                )));\n            }\n            if init.dims != donor.dims {\n                if is_paddable_shape(&init.dims, &donor.dims) {\n                    tracing::info!(\n                        name = %init.name,\n                        target = ?init.dims,\n                        donor = ?donor.dims,\n                        \"donor initializer will be zero-padded on last axis\"\n                    );\n                } else {\n                    return Err(DsperseError::Pipeline(format!(\n                        \"shape mismatch for initializer '{}' in {context}: slice expects {:?}, consumer provides {:?}\",\n                        init.name, init.dims, donor.dims\n                    )));\n                }\n            }\n        } else {\n            tracing::debug!(\n                name = %init.name,\n                context,\n                \"initializer not in donor weights, retaining slice value\"\n            );\n        }\n    }\n    Ok(())\n}\n\nfn pad_float_data(\n    donor_data: &[f32],\n    target_dims: &[i64],\n    donor_dims: &[i64],\n    pad_val: f32,\n) -> Vec<f32> {\n    let last = target_dims.len() - 1;\n    let target_last = target_dims[last] as usize;\n    let donor_last = donor_dims[last] as usize;\n    let rows = donor_data.len() / donor_last.max(1);\n    let mut padded = Vec::with_capacity(rows * target_last);\n    for row in 0..rows {\n        let start = row * donor_last;\n        let end = start + donor_last;\n        padded.extend_from_slice(&donor_data[start..end.min(donor_data.len())]);\n        padded.resize(padded.len() + (target_last - donor_last), pad_val);\n    }\n    padded\n}\n\nfn pad_raw_data_f32(raw: &[u8], target_dims: &[i64], donor_dims: &[i64], pad_val: f32) -> Vec<u8> {\n    let donor_floats: Vec<f32> = raw\n        
.chunks_exact(4)\n        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))\n        .collect();\n    let padded = pad_float_data(&donor_floats, target_dims, donor_dims, pad_val);\n    padded.iter().flat_map(|f| f.to_le_bytes()).collect()\n}\n\npub fn replace_initializers(\n    model: &mut ModelProto,\n    donor_init_map: &HashMap<String, &TensorProto>,\n) -> Result<usize> {\n    let graph = model\n        .graph\n        .as_mut()\n        .ok_or_else(|| DsperseError::Pipeline(\"ONNX model missing graph\".into()))?;\n    let mut replaced = 0;\n    for init in &mut graph.initializer {\n        if let Some(donor) = donor_init_map.get(&init.name) {\n            if init.data_type != donor.data_type {\n                return Err(DsperseError::Pipeline(format!(\n                    \"dtype mismatch for initializer '{}' in replace_initializers: slice has dtype {}, consumer has dtype {}\",\n                    init.name, init.data_type, donor.data_type\n                )));\n            }\n            let needs_pad = init.dims != donor.dims && is_paddable_shape(&init.dims, &donor.dims);\n            if init.dims != donor.dims && !needs_pad {\n                return Err(DsperseError::Pipeline(format!(\n                    \"shape mismatch for initializer '{}' in replace_initializers: slice expects {:?}, consumer provides {:?}\",\n                    init.name, init.dims, donor.dims\n                )));\n            }\n            if needs_pad {\n                let is_bias = donor.dims.len() == 1;\n                let pad_val: f32 = if is_bias { -10.0 } else { 0.0 };\n                if !donor.float_data.is_empty() {\n                    init.float_data =\n                        pad_float_data(&donor.float_data, &init.dims, &donor.dims, pad_val);\n                    init.raw_data.clear();\n                } else if !donor.raw_data.is_empty() && donor.data_type == TensorProto::FLOAT {\n                    init.raw_data =\n                        
pad_raw_data_f32(&donor.raw_data, &init.dims, &donor.dims, pad_val);\n                    init.float_data.clear();\n                }\n                tracing::info!(\n                    name = %init.name,\n                    from = ?donor.dims,\n                    to = ?init.dims,\n                    \"padded donor initializer\"\n                );\n            } else {\n                init.float_data = donor.float_data.clone();\n                init.raw_data = donor.raw_data.clone();\n                init.double_data = donor.double_data.clone();\n                init.int32_data = donor.int32_data.clone();\n                init.int64_data = donor.int64_data.clone();\n            }\n            replaced += 1;\n        }\n    }\n    Ok(replaced)\n}\n\npub fn build_patched_onnx(\n    slice_onnx: &Path,\n    donor_init_map: &HashMap<String, &TensorProto>,\n) -> Result<tempfile::NamedTempFile> {\n    let mut model = load_model(slice_onnx)?;\n    replace_initializers(&mut model, donor_init_map)?;\n    let tmp = tempfile::NamedTempFile::with_suffix(\".onnx\")\n        .map_err(|e| DsperseError::Pipeline(format!(\"create temp file: {e}\")))?;\n    save_model(&model, tmp.path())?;\n    Ok(tmp)\n}\n\nfn model_opset_version(model: &ModelProto) -> i64 {\n    model\n        .opset_import\n        .iter()\n        .find(|o| o.domain.is_empty() || o.domain == \"ai.onnx\")\n        .map(|o| o.version)\n        .unwrap_or(1)\n}\n\nfn min_opset_for_op(op_type: &str) -> Option<i64> {\n    match op_type {\n        \"GridSample\" => Some(16),\n        \"ScatterND\" => Some(16),\n        \"ScatterElements\" => Some(16),\n        _ => None,\n    }\n}\n\npub fn normalize_opset(model: &mut ModelProto) -> usize {\n    let opset = model_opset_version(model);\n    if opset < 13 {\n        return 0;\n    }\n    let graph = match model.graph.as_mut() {\n        Some(g) => g,\n        None => return 0,\n    };\n    let mut required_opset = opset;\n    for node in graph.node.iter() {\n      
  if let Some(min) = min_opset_for_op(&node.op_type) {\n            required_opset = required_opset.max(min);\n        }\n    }\n    let mut new_initializers: Vec<TensorProto> = Vec::new();\n    let mut count = 0;\n    for node in &mut graph.node {\n        match node.op_type.as_str() {\n            \"Unsqueeze\" | \"Squeeze\" if node.input.len() == 1 => {\n                if let Some(axes) = get_attribute_ints(node, \"axes\") {\n                    let axes_name = format!(\"{}_axes_const\", node.name);\n                    new_initializers.push(TensorProto {\n                        name: axes_name.clone(),\n                        data_type: TensorProto::INT64,\n                        dims: vec![axes.len() as i64],\n                        int64_data: axes,\n                        ..Default::default()\n                    });\n                    node.input.push(axes_name);\n                    node.attribute.retain(|a| a.name != \"axes\");\n                    count += 1;\n                }\n            }\n            \"Reshape\" if opset < 14 => {\n                let had = node.attribute.iter().any(|a| a.name == \"allowzero\");\n                if had {\n                    node.attribute.retain(|a| a.name != \"allowzero\");\n                    count += 1;\n                }\n            }\n            _ => {}\n        }\n    }\n    graph.initializer.extend(new_initializers);\n    if required_opset > opset {\n        if let Some(entry) = model\n            .opset_import\n            .iter_mut()\n            .find(|o| o.domain.is_empty() || o.domain == \"ai.onnx\")\n        {\n            entry.version = required_opset;\n        }\n        tracing::info!(\n            from = opset,\n            to = required_opset,\n            \"bumped declared opset to match op requirements\"\n        );\n        count += 1;\n    }\n    if count > 0 {\n        tracing::info!(\n            opset = required_opset,\n            fixes = count,\n            \"normalized ONNX 
opset conventions\"\n        );\n    }\n    count\n}\n\npub fn normalize_for_circuit_backend(model: &mut ModelProto) -> usize {\n    let graph = match model.graph.as_mut() {\n        Some(g) => g,\n        None => return 0,\n    };\n    let folded_names = super::onnx_fold::propagate_constants(graph);\n    let folded = folded_names.len();\n    let fixed = fix_zero_dims(graph);\n    let count = flatten_matmul_inputs(graph) + materialize_reshape_targets(graph) + fixed;\n    let total = folded + count;\n    if total > 0 {\n        tracing::info!(\n            total,\n            folded,\n            \"normalized graph for circuit backend compatibility\"\n        );\n    }\n    total\n}\n\nfn fix_zero_dims(graph: &mut GraphProto) -> usize {\n    let mut shapes: HashMap<String, Vec<i64>> = HashMap::new();\n    for inp in &graph.input {\n        if let Some(s) = shape_from_value_info(inp)\n            && s.iter().all(|&d| d > 0)\n        {\n            shapes.insert(inp.name.clone(), s);\n        }\n    }\n    for init in &graph.initializer {\n        if !init.dims.is_empty() {\n            shapes.insert(init.name.clone(), init.dims.clone());\n        }\n    }\n    for vi in &graph.value_info {\n        if let Some(s) = shape_from_value_info(vi)\n            && s.iter().all(|&d| d > 0)\n            && !shapes.contains_key(&vi.name)\n        {\n            shapes.insert(vi.name.clone(), s);\n        }\n    }\n\n    let mut count = 0;\n    for vi in graph.value_info.iter_mut().chain(graph.output.iter_mut()) {\n        if let Some(new_shape) = shapes.get(&vi.name)\n            && let Some(existing) = shape_from_value_info(vi)\n            && existing.contains(&0)\n        {\n            set_vi_shape(vi, new_shape);\n            count += 1;\n        }\n    }\n\n    if count > 0 {\n        tracing::info!(count, \"resolved zero-valued placeholder dimensions\");\n    }\n    count\n}\n\nfn flatten_matmul_inputs(graph: &mut GraphProto) -> usize {\n    let vi_shapes: 
HashMap<String, Vec<i64>> = graph\n        .input\n        .iter()\n        .chain(graph.value_info.iter())\n        .chain(graph.output.iter())\n        .filter_map(|vi| shape_from_value_info(vi).map(|s| (vi.name.clone(), s)))\n        .collect();\n\n    let init_shapes: HashMap<String, Vec<i64>> = graph\n        .initializer\n        .iter()\n        .map(|i| (i.name.clone(), i.dims.clone()))\n        .collect();\n\n    let shapes: HashMap<String, Vec<i64>> = vi_shapes.into_iter().chain(init_shapes).collect();\n\n    let elem_types: HashMap<String, i32> = graph\n        .input\n        .iter()\n        .chain(graph.value_info.iter())\n        .chain(graph.output.iter())\n        .filter_map(|vi| elem_type_from_value_info(vi).map(|t| (vi.name.clone(), t)))\n        .chain(\n            graph\n                .initializer\n                .iter()\n                .map(|i| (i.name.clone(), i.data_type)),\n        )\n        .collect();\n\n    let mut new_nodes: Vec<(usize, Vec<NodeProto>)> = Vec::new();\n    let mut new_inits: Vec<TensorProto> = Vec::new();\n    let mut new_vis: Vec<ValueInfoProto> = Vec::new();\n    let mut count = 0;\n\n    for (idx, node) in graph.node.iter().enumerate() {\n        if node.op_type != \"MatMul\" {\n            continue;\n        }\n        let a_name = match node.input.first() {\n            Some(n) if !n.is_empty() => n,\n            _ => continue,\n        };\n        let b_name = match node.input.get(1) {\n            Some(n) if !n.is_empty() => n,\n            _ => continue,\n        };\n        let a_shape = match shapes.get(a_name) {\n            Some(s) if s.len() > 3 => s.clone(),\n            _ => continue,\n        };\n        let b_shape = match shapes.get(b_name) {\n            Some(s) => s.clone(),\n            None => continue,\n        };\n        let out_name = match node.output.first() {\n            Some(n) if !n.is_empty() => n.clone(),\n            _ => continue,\n        };\n\n        let batch_dims = 
&a_shape[..a_shape.len() - 2];\n        let batch_vol: i64 = batch_dims.iter().product();\n        let m = a_shape[a_shape.len() - 2];\n        let k = a_shape[a_shape.len() - 1];\n\n        let node_tag = if node.name.is_empty() {\n            format!(\"matmul_{idx}\")\n        } else {\n            node.name.clone()\n        };\n        let a_2d_name = format!(\"{a_name}__flat2d_{node_tag}\");\n        let a_2d_shape_name = format!(\"{a_name}__flat2d_shape_{node_tag}\");\n        let a_2d = vec![batch_vol * m, k];\n\n        let mut b_2d_name = b_name.clone();\n        let mut needs_b_reshape = false;\n        let n_dim;\n        if b_shape.len() > 2 {\n            let b_m = b_shape[b_shape.len() - 2];\n            n_dim = b_shape[b_shape.len() - 1];\n            let b_batch: i64 = b_shape[..b_shape.len() - 2].iter().product();\n            if b_batch == 1 {\n                b_2d_name = format!(\"{b_name}__flat2d_{node_tag}\");\n                let b_2d_shape_name = format!(\"{b_name}__flat2d_shape_{node_tag}\");\n                let b_2d = vec![b_batch * b_m, n_dim];\n                new_inits.push(TensorProto {\n                    name: b_2d_shape_name.clone(),\n                    data_type: TensorProto::INT64,\n                    dims: vec![2],\n                    int64_data: b_2d.clone(),\n                    ..Default::default()\n                });\n                let b_elem = elem_types\n                    .get(b_name)\n                    .copied()\n                    .unwrap_or(TensorProto::FLOAT);\n                new_vis.push(make_tensor_value_info(&b_2d_name, b_elem, &b_2d));\n                needs_b_reshape = true;\n            }\n        } else {\n            n_dim = *b_shape.last().unwrap_or(&1);\n        }\n\n        let matmul_out_name = format!(\"{out_name}__matmul2d_{node_tag}\");\n        let matmul_2d_shape = vec![batch_vol * m, n_dim];\n\n        let restore_shape_name = format!(\"{out_name}__restore_shape_{node_tag}\");\n        let 
mut restored: Vec<i64> = batch_dims.to_vec();\n        restored.push(m);\n        if b_shape.len() > 1 {\n            restored.push(n_dim);\n        }\n\n        new_inits.push(TensorProto {\n            name: a_2d_shape_name.clone(),\n            data_type: TensorProto::INT64,\n            dims: vec![2],\n            int64_data: a_2d.clone(),\n            ..Default::default()\n        });\n        new_inits.push(TensorProto {\n            name: restore_shape_name.clone(),\n            data_type: TensorProto::INT64,\n            dims: vec![restored.len() as i64],\n            int64_data: restored.clone(),\n            ..Default::default()\n        });\n\n        let a_elem = elem_types\n            .get(a_name)\n            .copied()\n            .unwrap_or(TensorProto::FLOAT);\n        new_vis.push(make_tensor_value_info(&a_2d_name, a_elem, &a_2d));\n        new_vis.push(make_tensor_value_info(\n            &matmul_out_name,\n            a_elem,\n            &matmul_2d_shape,\n        ));\n\n        let mut inserted = Vec::new();\n\n        inserted.push(NodeProto {\n            op_type: \"Reshape\".into(),\n            name: format!(\"{}_flatten_a\", node.name),\n            input: vec![a_name.clone(), a_2d_shape_name],\n            output: vec![a_2d_name.clone()],\n            ..Default::default()\n        });\n\n        if needs_b_reshape {\n            let b_2d_shape_name = format!(\"{b_name}__flat2d_shape_{node_tag}\");\n            inserted.push(NodeProto {\n                op_type: \"Reshape\".into(),\n                name: format!(\"{}_flatten_b\", node.name),\n                input: vec![b_name.clone(), b_2d_shape_name],\n                output: vec![b_2d_name.clone()],\n                ..Default::default()\n            });\n        }\n\n        inserted.push(NodeProto {\n            op_type: \"MatMul\".into(),\n            name: node.name.clone(),\n            input: vec![a_2d_name, b_2d_name],\n            output: vec![matmul_out_name.clone()],\n        
    attribute: node.attribute.clone(),\n            ..Default::default()\n        });\n\n        inserted.push(NodeProto {\n            op_type: \"Reshape\".into(),\n            name: format!(\"{}_restore\", node.name),\n            input: vec![matmul_out_name, restore_shape_name],\n            output: vec![out_name],\n            ..Default::default()\n        });\n\n        new_nodes.push((idx, inserted));\n        count += 1;\n    }\n\n    let mut cumulative_offset: usize = 0;\n    for (idx, nodes) in new_nodes {\n        let pos = idx + cumulative_offset;\n        graph.node.remove(pos);\n        let inserted = nodes.len();\n        for (i, n) in nodes.into_iter().enumerate() {\n            graph.node.insert(pos + i, n);\n        }\n        cumulative_offset += inserted - 1;\n    }\n    graph.initializer.extend(new_inits);\n    graph.value_info.extend(new_vis);\n    count\n}\n\nfn materialize_reshape_targets(graph: &mut GraphProto) -> usize {\n    let mut init_names: HashSet<String> =\n        graph.initializer.iter().map(|i| i.name.clone()).collect();\n    let input_names: HashSet<String> = graph.input.iter().map(|i| i.name.clone()).collect();\n    let produced_names: HashSet<String> = graph\n        .node\n        .iter()\n        .flat_map(|n| n.output.iter().cloned())\n        .collect();\n\n    let vi_shapes: HashMap<String, Vec<i64>> = graph\n        .value_info\n        .iter()\n        .chain(graph.output.iter())\n        .filter_map(|vi| shape_from_value_info(vi).map(|s| (vi.name.clone(), s)))\n        .collect();\n\n    let mut new_inits: Vec<TensorProto> = Vec::new();\n    let mut count = 0;\n\n    for node in &graph.node {\n        if node.op_type != \"Reshape\" {\n            continue;\n        }\n        let shape_input = match node.input.get(1) {\n            Some(n) if !n.is_empty() => n,\n            _ => continue,\n        };\n        if init_names.contains(shape_input)\n            || input_names.contains(shape_input)\n            || 
produced_names.contains(shape_input)\n        {\n            continue;\n        }\n        let out_name = match node.output.first() {\n            Some(n) if !n.is_empty() => n,\n            _ => continue,\n        };\n        let out_shape = match vi_shapes.get(out_name) {\n            Some(s) if !s.is_empty() && s.iter().all(|&d| d > 0) => s,\n            _ => continue,\n        };\n        new_inits.push(TensorProto {\n            name: shape_input.clone(),\n            data_type: TensorProto::INT64,\n            dims: vec![out_shape.len() as i64],\n            int64_data: out_shape.clone(),\n            ..Default::default()\n        });\n        init_names.insert(shape_input.clone());\n        count += 1;\n    }\n\n    graph.initializer.extend(new_inits);\n    count\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/onnx_shapes.rs",
    "content": "use super::onnx_proto::{ModelProto, ValueInfoProto, onnx};\n\npub fn shape_from_value_info(vi: &ValueInfoProto) -> Option<Vec<i64>> {\n    let tp = vi.r#type.as_ref()?;\n    let onnx::type_proto::Value::TensorType(tensor) = tp.value.as_ref()? else {\n        return None;\n    };\n    let shape_proto = tensor.shape.as_ref()?;\n    let mut dims = Vec::new();\n    for d in &shape_proto.dim {\n        match &d.value {\n            Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => dims.push(*v),\n            _ => return None,\n        }\n    }\n    Some(dims)\n}\n\npub fn elem_type_from_value_info(vi: &ValueInfoProto) -> Option<i32> {\n    let tp = vi.r#type.as_ref()?;\n    let onnx::type_proto::Value::TensorType(tensor) = tp.value.as_ref()? else {\n        return None;\n    };\n    Some(tensor.elem_type)\n}\n\npub fn vi_shape(vi: &ValueInfoProto) -> Vec<i64> {\n    vi.r#type\n        .as_ref()\n        .and_then(|t| match &t.value {\n            Some(onnx::type_proto::Value::TensorType(tt)) => tt.shape.as_ref(),\n            _ => None,\n        })\n        .map(|s| {\n            s.dim\n                .iter()\n                .map(|d| match &d.value {\n                    Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => *v,\n                    _ => 0,\n                })\n                .collect()\n        })\n        .unwrap_or_default()\n}\n\npub fn set_vi_shape(vi: &mut ValueInfoProto, shape: &[i64]) {\n    if let Some(ref mut tp) = vi.r#type\n        && let Some(onnx::type_proto::Value::TensorType(ref mut tt)) = tp.value\n    {\n        tt.shape = Some(onnx::TensorShapeProto {\n            dim: shape\n                .iter()\n                .map(|&d| onnx::tensor_shape_proto::Dimension {\n                    denotation: String::new(),\n                    value: Some(onnx::tensor_shape_proto::dimension::Value::DimValue(d)),\n                })\n                .collect(),\n        });\n    }\n}\n\npub fn 
strip_symbolic_value_info(model: &mut ModelProto) -> usize {\n    let graph = match model.graph.as_mut() {\n        Some(g) => g,\n        None => return 0,\n    };\n\n    let has_symbolic = |vi: &ValueInfoProto| -> bool {\n        vi.r#type\n            .as_ref()\n            .and_then(|t| match &t.value {\n                Some(onnx::type_proto::Value::TensorType(tt)) => tt.shape.as_ref(),\n                _ => None,\n            })\n            .is_some_and(|s| {\n                s.dim.iter().any(|d| {\n                    matches!(\n                        &d.value,\n                        Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_))\n                    )\n                })\n            })\n    };\n\n    let before = graph.value_info.len();\n    graph.value_info.retain(|vi| !has_symbolic(vi));\n    let removed = before - graph.value_info.len();\n\n    for out in &mut graph.output {\n        if let Some(ref mut tp) = out.r#type\n            && let Some(onnx::type_proto::Value::TensorType(ref mut tt)) = tp.value\n            && let Some(ref mut shape) = tt.shape\n        {\n            for d in &mut shape.dim {\n                if matches!(\n                    &d.value,\n                    Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_))\n                ) {\n                    d.value = None;\n                }\n            }\n        }\n    }\n\n    if removed > 0 {\n        tracing::info!(\n            removed,\n            \"stripped value_info entries with symbolic dimensions\"\n        );\n    }\n    removed\n}\n\npub fn resolve_dynamic_input_shapes(\n    model: &mut ModelProto,\n    explicit_shape: Option<&[i64]>,\n) -> crate::error::Result<usize> {\n    let graph = match model.graph.as_mut() {\n        Some(g) => g,\n        None => return Ok(0),\n    };\n    let has_non_batch_symbolic = |inp: &&ValueInfoProto| -> bool {\n        inp.r#type\n            .as_ref()\n            .and_then(|t| match &t.value {\n              
  Some(onnx::type_proto::Value::TensorType(tt)) => tt.shape.as_ref(),\n                _ => None,\n            })\n            .is_some_and(|s| {\n                s.dim.iter().skip(1).any(|d| {\n                    matches!(\n                        &d.value,\n                        Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None\n                    )\n                })\n            })\n    };\n    let symbolic_count = graph.input.iter().filter(has_non_batch_symbolic).count();\n    if symbolic_count > 1 && explicit_shape.is_some() {\n        return Err(crate::error::DsperseError::Slicer(format!(\n            \"model has {symbolic_count} inputs with non-batch dynamic dimensions; \\\n             --input-shape applies to a single input. Per-input shapes not yet supported.\"\n        )));\n    }\n\n    let mut resolved = 0;\n    for inp in &mut graph.input {\n        let tp = match inp.r#type.as_mut() {\n            Some(t) => t,\n            None => continue,\n        };\n        let tensor = match &mut tp.value {\n            Some(onnx::type_proto::Value::TensorType(tt)) => tt,\n            _ => continue,\n        };\n        let shape = match tensor.shape.as_mut() {\n            Some(s) => s,\n            None => continue,\n        };\n        let has_symbolic = shape.dim.iter().any(|d| {\n            matches!(\n                &d.value,\n                Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None\n            )\n        });\n        if !has_symbolic {\n            continue;\n        }\n        if let Some(explicit) = explicit_shape {\n            if explicit.len() != shape.dim.len() {\n                return Err(crate::error::DsperseError::Slicer(format!(\n                    \"input '{}' has rank {} but --input-shape provides {} dims\",\n                    inp.name,\n                    shape.dim.len(),\n                    explicit.len()\n                )));\n            }\n            for (d, &v) in 
shape.dim.iter_mut().zip(explicit.iter()) {\n                if let Some(onnx::tensor_shape_proto::dimension::Value::DimValue(existing)) =\n                    &d.value\n                {\n                    if *existing != v {\n                        return Err(crate::error::DsperseError::Slicer(format!(\n                            \"input '{}': --input-shape dim {} conflicts with fixed dim {}\",\n                            inp.name, v, existing\n                        )));\n                    }\n                } else {\n                    d.value = Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v));\n                }\n            }\n            tracing::info!(input = %inp.name, shape = ?explicit, \"applied explicit input shape\");\n            resolved += 1;\n            continue;\n        }\n        let non_batch_symbolic = shape.dim.iter().skip(1).any(|d| {\n            matches!(\n                &d.value,\n                Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None\n            )\n        });\n        if non_batch_symbolic {\n            let dim_names: Vec<String> = shape\n                .dim\n                .iter()\n                .map(|d| match &d.value {\n                    Some(onnx::tensor_shape_proto::dimension::Value::DimParam(s)) => s.clone(),\n                    Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => v.to_string(),\n                    None => \"?\".into(),\n                })\n                .collect();\n            return Err(crate::error::DsperseError::Slicer(format!(\n                \"model input '{}' has dynamic dimensions [{}]; provide --input-shape to set concrete values\",\n                inp.name,\n                dim_names.join(\", \")\n            )));\n        }\n        shape.dim[0].value = Some(onnx::tensor_shape_proto::dimension::Value::DimValue(1));\n        tracing::info!(input = %inp.name, \"defaulted batch dimension to 1\");\n        resolved += 1;\n    
}\n    Ok(resolved)\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/onnx_slicer.rs",
    "content": "use std::collections::{HashMap, HashSet};\nuse std::path::Path;\n\nuse super::analyzer::{self, AnalysisResult, NodeAnalysis};\nuse super::autotiler;\nuse super::materializer;\nuse super::onnx_proto;\nuse crate::error::{DsperseError, Result};\nuse crate::schema::metadata::{\n    Dependencies, ModelMetadata, SliceMetadata, SliceShapeWrapper, TensorShape,\n};\nuse crate::schema::tiling::DimSplitInfo;\n\npub fn slice_model(\n    onnx_path: &Path,\n    output_path: Option<&Path>,\n    tile_size: Option<usize>,\n    jstprove_ops: &[&str],\n    input_shape: Option<&[i64]>,\n) -> Result<ModelMetadata> {\n    let mut model = onnx_proto::load_model(onnx_path)?;\n    onnx_proto::normalize_opset(&mut model);\n    onnx_proto::resolve_dynamic_input_shapes(&mut model, input_shape)?;\n\n    onnx_proto::strip_symbolic_value_info(&mut model);\n    let folded_constants = super::onnx_fold::fold_constant_nodes(&mut model);\n\n    let tmp_dir = tempfile::tempdir().map_err(|e| DsperseError::io(e, onnx_path))?;\n    let tract_path = tmp_dir.path().join(\"tract_model.onnx\");\n    onnx_proto::save_model(&model, &tract_path)?;\n\n    tracing::info!(\"folding constants and tracing shapes via tract\");\n    let trace_result = super::trace::fold_and_trace_via_tract(&tract_path, &model)?;\n    let mut traced_shapes = trace_result.shapes;\n    let traced_types = trace_result.types;\n\n    if let Some(graph) = model.graph.as_mut() {\n        // Chains of shape-dependent ops (Shape -> Gather -> Reshape,\n        // or nested ConstantOfShape pyramids) expose constants only\n        // after earlier rounds have folded their producers, so run\n        // propagate_constants_with_shapes to a fixpoint.  
A small\n        // safety cap prevents an unexpected non-monotonic evaluator\n        // from spinning indefinitely; propagation is monotone by\n        // construction so the loop is expected to converge in O(1)\n        // iterations even for the deepest chains we have observed.\n        const SHAPE_CONST_PROP_ITERATION_CAP: usize = 16;\n        let mut total_folded = 0usize;\n        for pass in 0..SHAPE_CONST_PROP_ITERATION_CAP {\n            let folded = super::onnx_fold::propagate_constants_with_shapes(graph, &traced_shapes);\n            if folded == 0 {\n                break;\n            }\n            total_folded += folded;\n            tracing::info!(pass, folded, \"shape-constant propagation pass\");\n        }\n        if total_folded > 0 {\n            tracing::info!(\n                total_folded,\n                \"propagated shape-derived constants in parent graph\"\n            );\n        }\n    }\n\n    let fused_ln = super::layernorm_fuse::fuse_inline_layernorms(&mut model, &mut traced_shapes);\n    if fused_ln > 0 {\n        tracing::info!(fused_ln, \"fused inline LayerNorm patterns\");\n    }\n\n    let self_div_rewrites =\n        super::self_div_rewrite::rewrite_self_div_to_one(&mut model, &mut traced_shapes);\n    if self_div_rewrites > 0 {\n        tracing::info!(self_div_rewrites, \"rewrote degenerate Div(X, X) nodes\");\n    }\n\n    let missing: Vec<String> = if let Some(graph) = &model.graph {\n        let mut missing = Vec::new();\n        for n in &graph.node {\n            for out in &n.output {\n                if !out.is_empty() && !traced_shapes.contains_key(out) {\n                    missing.push(out.clone());\n                }\n            }\n        }\n        missing\n    } else {\n        Vec::new()\n    };\n    if !missing.is_empty() {\n        tracing::warn!(count = missing.len(), first_few = ?&missing[..missing.len().min(5)], \"unresolved tensor shapes after all inference passes\");\n    }\n\n    let analysis = 
analyzer::analyze(&model, Some(onnx_path))?;\n\n    let output_dir = output_path.map(|p| p.to_path_buf()).unwrap_or_else(|| {\n        onnx_path\n            .parent()\n            .unwrap_or_else(|| Path::new(\".\"))\n            .join(\"slices\")\n    });\n    std::fs::create_dir_all(&output_dir).map_err(|e| DsperseError::io(e, &output_dir))?;\n\n    let slice_points =\n        determine_slice_points(&analysis, tile_size, jstprove_ops, &model, &traced_shapes);\n    tracing::info!(points = ?slice_points, \"determined slice points\");\n    debug_assert!(\n        !slice_points.is_empty(),\n        \"complete_slice_points guarantees at least [0, end]\"\n    );\n\n    let model_dest = output_dir.join(\"model.onnx\");\n    onnx_proto::save_model(&model, &model_dest)?;\n\n    let segment_ranges = super::build_segment_ranges(&slice_points, None);\n\n    let trimmed_points = &slice_points[..slice_points.len().saturating_sub(1)];\n\n    let mut tiled_info = HashMap::new();\n    let mut dim_split_info: HashMap<usize, (autotiler::DimSplitDetection, Option<String>)> =\n        HashMap::new();\n    for (seg_idx, _) in segment_ranges.iter().enumerate() {\n        let slice_model = materializer::materialize_slice_model(\n            &model,\n            trimmed_points,\n            &traced_shapes,\n            &traced_types,\n            seg_idx,\n        )?;\n        if let Some(detection) = autotiler::detect_tiling_needs(&slice_model, tile_size) {\n            tiled_info.insert(seg_idx, detection);\n            continue;\n        }\n        if let Some(graph) = slice_model.graph.as_ref() {\n            let init_names: HashSet<String> =\n                graph.initializer.iter().map(|t| t.name.clone()).collect();\n            let mut slice_shapes: HashMap<String, Vec<i64>> = HashMap::new();\n            for vi in graph\n                .input\n                .iter()\n                .chain(graph.output.iter())\n                .chain(graph.value_info.iter())\n            {\n   
             let dims = onnx_proto::vi_shape(vi);\n                if !dims.is_empty() {\n                    slice_shapes.insert(vi.name.clone(), dims);\n                }\n            }\n            for init in &graph.initializer {\n                slice_shapes\n                    .entry(init.name.clone())\n                    .or_insert_with(|| init.dims.to_vec());\n            }\n            for (name, shape) in &traced_shapes {\n                slice_shapes\n                    .entry(name.clone())\n                    .or_insert_with(|| shape.clone());\n            }\n            if let Some(detection) = autotiler::detect_dim_split(\n                &graph.node,\n                &slice_shapes,\n                &init_names,\n                autotiler::model_opset(&model),\n            ) {\n                // Build a tentative DimSplitInfo to attempt template creation.\n                // Only record the detection if the template materializes\n                // successfully, so the metadata never carries dim_split\n                // entries that can't be fulfilled at runtime.\n                let tentative_info = DimSplitInfo::from_detection(&detection, seg_idx, None);\n                let slice_dir = output_dir.join(format!(\"slice_{seg_idx}\")).join(\"payload\");\n                std::fs::create_dir_all(&slice_dir).map_err(|e| DsperseError::io(e, &slice_dir))?;\n                match autotiler::create_dim_split_template(\n                    &slice_model,\n                    &tentative_info,\n                    &slice_dir,\n                    Some(&traced_shapes),\n                ) {\n                    Ok(tmpl_path) => {\n                        let tmpl_rel = tmpl_path\n                            .strip_prefix(&output_dir)\n                            .map_err(|_| {\n                                DsperseError::Slicer(format!(\n                                    \"dim-split template path {} is not under output dir {}\",\n                          
          tmpl_path.display(),\n                                    output_dir.display()\n                                ))\n                            })?\n                            .to_string_lossy()\n                            .into_owned();\n                        tracing::info!(\n                            slice = seg_idx,\n                            estimated = detection.estimated_constraints,\n                            num_groups = detection.num_groups,\n                            split_kind = ?detection.split_kind,\n                            \"dim-split detected and template created\"\n                        );\n                        dim_split_info.insert(seg_idx, (detection, Some(tmpl_rel)));\n                    }\n                    Err(e) => {\n                        tracing::warn!(\n                            slice = seg_idx,\n                            estimated = detection.estimated_constraints,\n                            error = %e,\n                            \"dim-split detected but template creation failed; \\\n                             slice will be skipped during compilation\"\n                        );\n                        // Record detection with no template path so the\n                        // compiler knows this slice was over-budget and\n                        // should be skipped rather than falling through\n                        // to monolithic compilation.\n                        dim_split_info.insert(seg_idx, (detection, None));\n                    }\n                }\n            }\n        }\n    }\n\n    let slices = build_slice_metadata(\n        &analysis,\n        &slice_points,\n        &segment_ranges,\n        &traced_shapes,\n        &tiled_info,\n        &dim_split_info,\n    );\n\n    let mut metadata = ModelMetadata {\n        original_model: analysis.original_model.clone().unwrap_or_default(),\n        model_type: analysis.model_type.clone(),\n        input_shape: 
analysis.input_shape.clone(),\n        output_shapes: analysis.output_shapes.clone(),\n        output_names: analysis.output_names.clone(),\n        slice_points: slice_points[..slice_points.len().saturating_sub(1)].to_vec(),\n        slices,\n        dsperse_version: None,\n        dsperse_rev: None,\n        jstprove_version: None,\n        jstprove_rev: None,\n        traced_shapes: Some(traced_shapes),\n        traced_types: Some(traced_types),\n        original_model_path: Some(\"model.onnx\".to_string()),\n        folded_constant_names: folded_constants.into_iter().collect(),\n    };\n    metadata.stamp_version();\n    metadata.save(&output_dir.join(crate::utils::paths::METADATA_FILE))?;\n\n    tracing::info!(\n        slices = metadata.slices.len(),\n        tiled = tiled_info.len(),\n        \"slicing complete\"\n    );\n\n    Ok(metadata)\n}\n\nfn build_slice_metadata(\n    analysis: &AnalysisResult,\n    _slice_points: &[usize],\n    segment_ranges: &[(usize, usize)],\n    traced_shapes: &HashMap<String, Vec<i64>>,\n    tiled_info: &HashMap<usize, autotiler::TilingDetection>,\n    dim_split_info: &HashMap<usize, (autotiler::DimSplitDetection, Option<String>)>,\n) -> Vec<SliceMetadata> {\n    let mut slices = Vec::new();\n\n    for (seg_idx, &(start, end)) in segment_ranges.iter().enumerate() {\n        let dependencies = analyzer::get_segment_dependencies(analysis, start, end);\n\n        let shape = build_shape_from_traced(analysis, start, end, &dependencies, traced_shapes);\n\n        let filename = format!(\"slice_{seg_idx}.onnx\");\n        let relative_path = format!(\"slice_{seg_idx}/payload/{filename}\");\n\n        let mut tiling = None;\n        let mut channel_split = None;\n        if let Some(detection) = tiled_info.get(&seg_idx) {\n            match detection {\n                autotiler::TilingDetection::Spatial {\n                    input_name,\n                    output_name,\n                    input_names,\n                    ndim,\n 
                   c_in,\n                    c_out,\n                    h,\n                    w,\n                    tile_size: actual_tile,\n                    halo,\n                    tiles_y,\n                    tiles_x,\n                    out_tile,\n                    stride,\n                } => {\n                    tiling = Some(crate::schema::tiling::TilingInfo {\n                        slice_idx: seg_idx,\n                        tile_size: *actual_tile as usize,\n                        num_tiles: (*tiles_y * *tiles_x) as usize,\n                        tiles_y: *tiles_y as usize,\n                        tiles_x: *tiles_x as usize,\n                        halo: *halo,\n                        out_tile: *out_tile,\n                        stride: *stride,\n                        c_in: *c_in as usize,\n                        c_out: *c_out as usize,\n                        input_name: input_name.clone(),\n                        output_name: output_name.clone(),\n                        input_names: input_names.clone(),\n                        ndim: *ndim as usize,\n                        h: *h as usize,\n                        w: *w as usize,\n                        tile: Some(crate::schema::tiling::TileInfo {\n                            path: format!(\"slice_{seg_idx}/payload/tiles/tile.onnx\"),\n                            conv_out: *out_tile,\n                            jstprove_circuit_path: None,\n                        }),\n                        tiles: None,\n                        segment_size: None,\n                        total_elements: None,\n                        original_shape: vec![],\n                    });\n                }\n                autotiler::TilingDetection::FixedSegment {\n                    input_name,\n                    output_name,\n                    input_names,\n                    total_elements,\n                    segment_size,\n                    num_segments,\n                    
original_shape,\n                } => {\n                    tiling = Some(crate::schema::tiling::TilingInfo {\n                        slice_idx: seg_idx,\n                        tile_size: *segment_size as usize,\n                        num_tiles: *num_segments as usize,\n                        tiles_y: *num_segments as usize,\n                        tiles_x: 1,\n                        halo: [0, 0, 0, 0],\n                        out_tile: [*segment_size, 1],\n                        stride: [1, 1],\n                        c_in: 1,\n                        c_out: 1,\n                        input_name: input_name.clone(),\n                        output_name: output_name.clone(),\n                        input_names: input_names.clone(),\n                        ndim: 1,\n                        h: *total_elements as usize,\n                        w: 1,\n                        tile: Some(crate::schema::tiling::TileInfo {\n                            path: format!(\"slice_{seg_idx}/payload/tiles/tile.onnx\"),\n                            conv_out: [*segment_size, 1],\n                            jstprove_circuit_path: None,\n                        }),\n                        tiles: None,\n                        segment_size: Some(*segment_size as usize),\n                        total_elements: Some(*total_elements as usize),\n                        original_shape: original_shape.clone(),\n                    });\n                }\n                autotiler::TilingDetection::ChannelSplit {\n                    input_name,\n                    output_name,\n                    c_in,\n                    c_out,\n                    h,\n                    w,\n                    num_groups,\n                    channels_per_group,\n                } => {\n                    channel_split = Some(crate::schema::tiling::ChannelSplitInfo {\n                        slice_idx: seg_idx,\n                        c_in: *c_in as usize,\n                        
c_out: *c_out as usize,\n                        num_groups: *num_groups as usize,\n                        channels_per_group: *channels_per_group as usize,\n                        input_name: input_name.clone(),\n                        output_name: output_name.clone(),\n                        h: *h as usize,\n                        w: *w as usize,\n                        out_h: 0,\n                        out_w: 0,\n                        groups: Vec::new(),\n                        bias_path: None,\n                    });\n                }\n            }\n        }\n\n        let dim_split = dim_split_info\n            .get(&seg_idx)\n            .map(|(d, tmpl_rel)| DimSplitInfo::from_detection(d, seg_idx, tmpl_rel.clone()));\n\n        slices.push(SliceMetadata {\n            index: seg_idx,\n            filename: filename.clone(),\n            path: format!(\"payload/{filename}\"),\n            relative_path,\n            shape: SliceShapeWrapper {\n                tensor_shape: shape,\n            },\n            dependencies,\n            tiling,\n            channel_split,\n            dim_split,\n            compilation: Default::default(),\n            slice_metadata: None,\n            slice_metadata_relative_path: None,\n        });\n    }\n\n    slices\n}\n\nfn build_shape_from_traced(\n    _analysis: &AnalysisResult,\n    _start: usize,\n    _end: usize,\n    dependencies: &Dependencies,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n) -> TensorShape {\n    let input_shapes: Vec<Vec<i64>> = dependencies\n        .filtered_inputs\n        .iter()\n        .filter_map(|name| traced_shapes.get(name).cloned())\n        .collect();\n\n    let output_shapes: Vec<Vec<i64>> = dependencies\n        .output\n        .iter()\n        .filter_map(|name| traced_shapes.get(name).cloned())\n        .collect();\n\n    TensorShape {\n        input: input_shapes,\n        output: output_shapes,\n    }\n}\n\nfn determine_slice_points(\n    analysis: 
&AnalysisResult,\n    tile_size: Option<usize>,\n    jstprove_ops: &[&str],\n    model: &onnx_proto::ModelProto,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n) -> Vec<usize> {\n    let mut points: HashSet<usize> = HashSet::new();\n\n    for node in analysis.nodes.values() {\n        if !node.parameter_details.is_empty() {\n            points.insert(node.index);\n        }\n    }\n\n    let mut sorted_points: Vec<usize> = points.into_iter().collect();\n    sorted_points.sort();\n\n    sorted_points = isolate_conv(&sorted_points, analysis);\n    sorted_points = isolate_expensive_ops(&sorted_points, analysis, model, traced_shapes);\n    sorted_points = optimize_jstprove_slices(&sorted_points, analysis, jstprove_ops);\n\n    if tile_size.is_some() {\n        sorted_points = optimize_for_tiling(&sorted_points, analysis);\n    }\n\n    sorted_points = filter_constant_only_slices(&sorted_points, analysis);\n    sorted_points = merge_control_flow_segments(&sorted_points, analysis);\n    sorted_points.sort();\n    sorted_points.dedup();\n\n    complete_slice_points(&mut sorted_points, analysis);\n    sorted_points\n}\n\nfn optimize_points(\n    points: &[usize],\n    analysis: &AnalysisResult,\n    mutate: impl FnOnce(&mut HashSet<usize>, &[&NodeAnalysis], usize),\n) -> Vec<usize> {\n    let mut updated: HashSet<usize> = points.iter().copied().collect();\n    let mut sorted_nodes: Vec<&NodeAnalysis> = analysis.nodes.values().collect();\n    sorted_nodes.sort_by_key(|n| n.index);\n    let max_idx = sorted_nodes.last().map(|n| n.index).unwrap_or(0);\n    mutate(&mut updated, &sorted_nodes, max_idx);\n    let mut v: Vec<usize> = updated.into_iter().filter(|&p| p <= max_idx).collect();\n    v.sort();\n    v\n}\n\nfn is_spatial_primary(op: &str) -> bool {\n    op == \"Conv\" || op == \"MaxPool\"\n}\n\n/// Insert slice points before AND after every ONNX node whose\n/// estimated constraint count exceeds\n/// [`autotiler::MAX_ESTIMATED_CONSTRAINTS`].  
Each \"expensive\" op\n/// (large MatMul, LayerNormalization, Softmax, etc.) becomes a\n/// single-node slice so the dim-split detector sees an unambiguous\n/// shape and the runner doesn't need to trace which axis lives where\n/// through Transpose / Reshape neighbours.  Small ops keep their\n/// existing grouping for circuit catalog reuse.\nfn isolate_expensive_ops(\n    points: &[usize],\n    analysis: &AnalysisResult,\n    model: &onnx_proto::ModelProto,\n    traced_shapes: &HashMap<String, Vec<i64>>,\n) -> Vec<usize> {\n    use jstprove_circuits::api::{EstimationConfig, estimate_op_constraints};\n    let cfg = EstimationConfig::bn254_defaults();\n    let threshold = autotiler::MAX_ESTIMATED_CONSTRAINTS;\n\n    // Build a parallel index: ONNX-node-index -> &NodeProto so we can\n    // resolve input/output tensor names per slicer-node.\n    let onnx_nodes: Vec<&onnx_proto::NodeProto> = model\n        .graph\n        .as_ref()\n        .map(|g| g.node.iter().collect())\n        .unwrap_or_default();\n    // Resolve a tensor's traced shape strictly: every dim must be a\n    // concrete positive value.  Coercing dynamic / -1 / 0 dims to 1\n    // would silently drive the cost estimate to ~zero and let the\n    // very nodes this pass exists to isolate sneak through.  Returning\n    // `None` for an unresolved tensor is the signal to pessimistically\n    // isolate the node anyway.\n    let to_usize_shape = |name: &String| -> Option<Vec<usize>> {\n        let shape = traced_shapes.get(name)?;\n        let mut out = Vec::with_capacity(shape.len());\n        for &d in shape {\n            if d <= 0 {\n                return None;\n            }\n            out.push(d as usize);\n        }\n        Some(out)\n    };\n\n    // Pure elementwise binary ops (Add / Sub / Mul / Div / Pow) are\n    // never isolated.  
This is a coupling to jstprove_circuits's\n    // single-op-slice invariants: when an isolated slice contains\n    // exactly one Div with a runtime divisor, one Mul / Sub between\n    // operands of broadcast-incompatible shapes, or one Pow whose\n    // exponent is a non-constant tensor, the per-op layer builder\n    // rejects the slice with a strict-mode error.  When the same\n    // pattern appears inside a larger multi-op slice the\n    // dim-split / LayerNorm fusion machinery rewrites the\n    // surrounding subgraph and the strict check passes.  These ops\n    // are also cheap to compile in absolute terms, so isolating them\n    // buys little proving wall-clock and surfaces the strict-mode\n    // failure more often.\n    //\n    // TODO: revisit when jstprove_circuits relaxes the single-op\n    // invariants (or exposes a \"permissive\" mode) so we can drop\n    // this exemption and let the autotiler decide based on cost.\n    let elementwise_skip: HashSet<&str> = [\"Add\", \"Sub\", \"Mul\", \"Div\", \"Pow\"].into_iter().collect();\n    optimize_points(points, analysis, |updated, sorted_nodes, max_idx| {\n        for node in sorted_nodes {\n            if elementwise_skip.contains(node.node_type.as_str()) {\n                continue;\n            }\n            let Some(onnx_node) = onnx_nodes.get(node.index) else {\n                continue;\n            };\n            // ONNX node inputs / outputs use \"\" to denote an\n            // unbound optional slot (e.g. Conv with no bias, GRU\n            // with no initial_h).  Treating those as unresolved\n            // boundary tensors makes every node carrying an empty\n            // slot pessimistically isolate, even when the real\n            // boundary tensors are fully shape-resolved.  
Skip the\n            // empty entries so estimate_op_constraints sees only\n            // the real boundary tensors.\n            let in_shapes: Option<Vec<Vec<usize>>> = onnx_node\n                .input\n                .iter()\n                .filter(|name| !name.is_empty())\n                .map(&to_usize_shape)\n                .collect();\n            let out_shapes: Option<Vec<Vec<usize>>> = onnx_node\n                .output\n                .iter()\n                .filter(|name| !name.is_empty())\n                .map(&to_usize_shape)\n                .collect();\n            // If any boundary tensor is unresolved we cannot give an\n            // honest cost estimate; isolate pessimistically so the\n            // downstream compile path sees a single-op slice and can\n            // either compile it successfully or skip it cleanly,\n            // rather than silently grouping an unbounded op.\n            let isolate = match (in_shapes, out_shapes) {\n                (Some(ins), Some(outs)) => {\n                    estimate_op_constraints(&node.node_type, &ins, &outs, &cfg) > threshold\n                }\n                _ => true,\n            };\n            if isolate {\n                updated.insert(node.index);\n                if node.index < max_idx {\n                    updated.insert(node.index + 1);\n                }\n            }\n        }\n    })\n}\n\nfn isolate_conv(points: &[usize], analysis: &AnalysisResult) -> Vec<usize> {\n    optimize_points(points, analysis, |updated, sorted_nodes, max_idx| {\n        for (pos, node) in sorted_nodes.iter().enumerate() {\n            if is_spatial_primary(&node.node_type) {\n                updated.insert(node.index);\n                let mut produced: HashSet<&str> = node\n                    .dependencies\n                    .output\n                    .iter()\n                    .map(|s| s.as_str())\n                    .collect();\n                let mut end = pos + 1;\n            
    while end < sorted_nodes.len() {\n                    let candidate = sorted_nodes[end];\n                    if !super::is_slice_passthrough(&candidate.node_type) {\n                        break;\n                    }\n                    let consumes_produced = candidate.dependencies.input.iter().any(|inp| {\n                        !analysis.initializer_names.contains(inp) && produced.contains(inp.as_str())\n                    });\n                    if !consumes_produced {\n                        break;\n                    }\n                    for out in &candidate.dependencies.output {\n                        produced.insert(out.as_str());\n                    }\n                    end += 1;\n                }\n                if end < sorted_nodes.len() && sorted_nodes[end].index <= max_idx {\n                    updated.insert(sorted_nodes[end].index);\n                }\n            }\n        }\n    })\n}\n\nfn optimize_jstprove_slices(\n    points: &[usize],\n    analysis: &AnalysisResult,\n    jstprove_ops: &[&str],\n) -> Vec<usize> {\n    optimize_points(points, analysis, |updated, sorted_nodes, _max_idx| {\n        let is_supported = |n: &NodeAnalysis| jstprove_ops.contains(&n.node_type.as_str());\n        for i in 0..sorted_nodes.len().saturating_sub(1) {\n            if is_supported(sorted_nodes[i]) != is_supported(sorted_nodes[i + 1]) {\n                updated.insert(sorted_nodes[i + 1].index);\n            }\n        }\n    })\n}\n\nfn optimize_for_tiling(points: &[usize], analysis: &AnalysisResult) -> Vec<usize> {\n    optimize_points(points, analysis, |updated, sorted_nodes, _max_idx| {\n        let is_tileable = |n: &NodeAnalysis| {\n            n.node_type == \"Conv\" || n.node_type == \"MaxPool\" || super::is_elementwise(&n.node_type)\n        };\n        for i in 0..sorted_nodes.len().saturating_sub(1) {\n            let curr = sorted_nodes[i];\n            let next = sorted_nodes[i + 1];\n            if !is_tileable(curr) && 
next.node_type == \"Relu\" {\n                continue;\n            }\n            if is_tileable(curr) != is_tileable(next) {\n                updated.insert(next.index);\n            }\n        }\n    })\n}\n\nfn filter_constant_only_slices(points: &[usize], analysis: &AnalysisResult) -> Vec<usize> {\n    if points.is_empty() {\n        return points.to_vec();\n    }\n    let nodes_by_idx: HashMap<usize, &NodeAnalysis> =\n        analysis.nodes.values().map(|n| (n.index, n)).collect();\n\n    let mut to_remove: HashSet<usize> = HashSet::new();\n    for (i, &end_idx) in points.iter().enumerate() {\n        let start_idx = if i > 0 { points[i - 1] } else { 0 };\n        if start_idx == end_idx {\n            continue;\n        }\n        let all_constant = (start_idx..end_idx).all(|idx| {\n            nodes_by_idx\n                .get(&idx)\n                .map(|n| n.node_type == \"Constant\")\n                .unwrap_or(true)\n        });\n        if all_constant {\n            to_remove.insert(end_idx);\n        }\n    }\n    if !to_remove.is_empty() {\n        tracing::info!(count = to_remove.len(), \"merged constant-only slices\");\n    }\n    points\n        .iter()\n        .filter(|p| !to_remove.contains(p))\n        .copied()\n        .collect()\n}\n\nfn merge_control_flow_segments(points: &[usize], analysis: &AnalysisResult) -> Vec<usize> {\n    let output_to_node_idx: HashMap<&str, usize> = analysis\n        .nodes\n        .values()\n        .flat_map(|n| {\n            n.dependencies\n                .output\n                .iter()\n                .map(move |o| (o.as_str(), n.index))\n        })\n        .collect();\n\n    let mut to_remove: HashSet<usize> = HashSet::new();\n    for node in analysis.nodes.values() {\n        if !super::is_control_flow(&node.node_type) {\n            continue;\n        }\n        for inp in &node.dependencies.input {\n            if let Some(&producer_idx) = output_to_node_idx.get(inp.as_str()) {\n                
for &pt in points {\n                    if pt > producer_idx && pt <= node.index {\n                        to_remove.insert(pt);\n                    }\n                }\n            }\n        }\n    }\n\n    if !to_remove.is_empty() {\n        tracing::info!(\n            count = to_remove.len(),\n            \"removed slice points to preserve control flow node dependencies\"\n        );\n    }\n\n    points\n        .iter()\n        .filter(|p| !to_remove.contains(p))\n        .copied()\n        .collect()\n}\n\nfn complete_slice_points(points: &mut Vec<usize>, analysis: &AnalysisResult) {\n    let max_index = analysis.nodes.values().map(|n| n.index).max().unwrap_or(0);\n    let end = max_index + 1;\n    if !points.contains(&0) {\n        points.push(0);\n    }\n    if !points.contains(&end) {\n        points.push(end);\n    }\n    points.sort();\n    points.dedup();\n}\n\npub(crate) fn broadcast_shapes(shapes: &[&Vec<i64>]) -> Option<Vec<i64>> {\n    if shapes.is_empty() {\n        return None;\n    }\n    let max_rank = shapes.iter().map(|s| s.len()).max().unwrap_or(0);\n    let mut result = vec![1i64; max_rank];\n    for shape in shapes {\n        let offset = max_rank - shape.len();\n        for (i, &dim) in shape.iter().enumerate() {\n            let ri = offset + i;\n            if result[ri] == 1 {\n                result[ri] = dim;\n            } else if dim != 1 && dim != result[ri] {\n                return None;\n            }\n        }\n    }\n    Some(result)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use analyzer::NodeDependencies;\n\n    fn make_analysis_with_params(nodes: Vec<(&str, usize, &str, bool)>) -> AnalysisResult {\n        let mut node_map = HashMap::new();\n        for (name, index, op_type, has_params) in &nodes {\n            let mut parameter_details = HashMap::new();\n            if *has_params {\n                parameter_details.insert(\n                    format!(\"{}_weight\", name),\n                    
analyzer::ParameterDetail {\n                        shape: vec![3, 3],\n                        size: 9,\n                    },\n                );\n            }\n            node_map.insert(\n                name.to_string(),\n                NodeAnalysis {\n                    index: *index,\n                    slice_name: format!(\"{}_{}\", op_type, index),\n                    node_type: op_type.to_string(),\n                    parameter_details,\n                    dependencies: NodeDependencies {\n                        input: vec![],\n                        output: vec![],\n                    },\n                },\n            );\n        }\n        AnalysisResult {\n            original_model: None,\n            model_type: \"ONNX\".to_string(),\n            node_count: nodes.len(),\n            initializer_count: 0,\n            input_shape: vec![],\n            output_shapes: vec![],\n            output_names: vec![],\n            opset_version: Some(18),\n            nodes: node_map,\n            initializer_names: HashSet::new(),\n        }\n    }\n\n    const TEST_OPS: &[&str] = &[\"Conv\", \"Gemm\", \"MatMul\"];\n\n    #[test]\n    fn complete_slice_points_adds_boundaries() {\n        let analysis = make_analysis_with_params(vec![\n            (\"a\", 0, \"Conv\", false),\n            (\"b\", 1, \"Relu\", false),\n            (\"c\", 2, \"Conv\", false),\n        ]);\n        let mut points = vec![1];\n        complete_slice_points(&mut points, &analysis);\n        assert!(points.contains(&0));\n        assert!(points.contains(&3));\n        assert!(points.contains(&1));\n    }\n\n    #[test]\n    fn complete_slice_points_already_complete() {\n        let analysis =\n            make_analysis_with_params(vec![(\"a\", 0, \"Conv\", false), (\"b\", 1, \"Relu\", false)]);\n        let mut points = vec![0, 2];\n        complete_slice_points(&mut points, &analysis);\n        assert_eq!(points, vec![0, 2]);\n    }\n\n    #[test]\n    fn 
complete_slice_points_deduplicates() {\n        let analysis = make_analysis_with_params(vec![(\"a\", 0, \"Conv\", false)]);\n        let mut points = vec![0, 0, 1, 1];\n        complete_slice_points(&mut points, &analysis);\n        assert_eq!(points, vec![0, 1]);\n    }\n\n    #[test]\n    fn isolate_conv_inserts_boundaries() {\n        let analysis = make_analysis_with_params(vec![\n            (\"a\", 0, \"Conv\", false),\n            (\"b\", 1, \"Relu\", false),\n            (\"c\", 2, \"MaxPool\", false),\n            (\"d\", 3, \"Conv\", false),\n            (\"e\", 4, \"Relu\", false),\n        ]);\n        let points = vec![0, 3];\n        let result = isolate_conv(&points, &analysis);\n        assert!(result.contains(&0));\n        assert!(result.contains(&1));\n        assert!(result.contains(&3));\n        assert!(result.contains(&4));\n    }\n\n    #[test]\n    fn isolate_conv_no_convs() {\n        let analysis =\n            make_analysis_with_params(vec![(\"a\", 0, \"Relu\", false), (\"b\", 1, \"Reshape\", false)]);\n        let points = vec![0];\n        let result = isolate_conv(&points, &analysis);\n        assert_eq!(result, vec![0]);\n    }\n\n    #[test]\n    fn isolate_maxpool_gets_boundary() {\n        let analysis =\n            make_analysis_with_params(vec![(\"a\", 0, \"Relu\", false), (\"b\", 1, \"MaxPool\", false)]);\n        let points = vec![0];\n        let result = isolate_conv(&points, &analysis);\n        assert_eq!(result, vec![0, 1]);\n    }\n\n    #[test]\n    fn optimize_jstprove_slices_splits_at_boundary() {\n        let analysis = make_analysis_with_params(vec![\n            (\"a\", 0, \"Conv\", false),\n            (\"b\", 1, \"Relu\", false),\n            (\"c\", 2, \"Conv\", false),\n        ]);\n        let points = vec![0];\n        let result = optimize_jstprove_slices(&points, &analysis, TEST_OPS);\n        assert!(result.contains(&1));\n        assert!(result.contains(&2));\n    }\n\n    #[test]\n    fn 
optimize_jstprove_slices_all_supported() {\n        let analysis =\n            make_analysis_with_params(vec![(\"a\", 0, \"Conv\", false), (\"b\", 1, \"Conv\", false)]);\n        let points = vec![0, 1];\n        let result = optimize_jstprove_slices(&points, &analysis, TEST_OPS);\n        assert_eq!(result, vec![0, 1]);\n    }\n\n    #[test]\n    fn optimize_for_tiling_maxpool_stays_grouped() {\n        let analysis = make_analysis_with_params(vec![\n            (\"a\", 0, \"Conv\", false),\n            (\"b\", 1, \"Relu\", false),\n            (\"c\", 2, \"MaxPool\", false),\n            (\"d\", 3, \"Conv\", false),\n        ]);\n        let points = vec![0, 3];\n        let result = optimize_for_tiling(&points, &analysis);\n        assert!(!result.contains(&2));\n    }\n\n    #[test]\n    fn optimize_for_tiling_splits_at_non_tileable() {\n        let analysis = make_analysis_with_params(vec![\n            (\"a\", 0, \"Conv\", false),\n            (\"b\", 1, \"Relu\", false),\n            (\"c\", 2, \"Reshape\", false),\n            (\"d\", 3, \"Conv\", false),\n        ]);\n        let points = vec![0, 3];\n        let result = optimize_for_tiling(&points, &analysis);\n        assert!(result.contains(&2));\n    }\n\n    #[test]\n    fn optimize_for_tiling_relu_after_non_tileable_kept() {\n        let analysis = make_analysis_with_params(vec![\n            (\"a\", 0, \"MaxPool\", false),\n            (\"b\", 1, \"Relu\", false),\n            (\"c\", 2, \"Conv\", false),\n        ]);\n        let points = vec![0, 2];\n        let result = optimize_for_tiling(&points, &analysis);\n        assert!(!result.contains(&1));\n    }\n\n    #[test]\n    fn filter_constant_only_slices_removes_constant_segments() {\n        let analysis = make_analysis_with_params(vec![\n            (\"a\", 0, \"Constant\", false),\n            (\"b\", 1, \"Constant\", false),\n            (\"c\", 2, \"Conv\", false),\n            (\"d\", 3, \"Relu\", false),\n        ]);\n        let 
points = vec![2, 4];\n        let result = filter_constant_only_slices(&points, &analysis);\n        assert!(!result.contains(&2));\n        assert!(result.contains(&4));\n    }\n\n    #[test]\n    fn filter_constant_only_slices_keeps_non_constant() {\n        let analysis =\n            make_analysis_with_params(vec![(\"a\", 0, \"Conv\", false), (\"b\", 1, \"Relu\", false)]);\n        let points = vec![1, 2];\n        let result = filter_constant_only_slices(&points, &analysis);\n        assert_eq!(result, vec![1, 2]);\n    }\n\n    #[test]\n    fn filter_constant_only_slices_empty_points() {\n        let analysis = make_analysis_with_params(vec![(\"a\", 0, \"Conv\", false)]);\n        let result = filter_constant_only_slices(&[], &analysis);\n        assert!(result.is_empty());\n    }\n\n    #[test]\n    fn determine_slice_points_includes_parameterized_nodes() {\n        let analysis = make_analysis_with_params(vec![\n            (\"conv0\", 0, \"Conv\", true),\n            (\"relu0\", 1, \"Relu\", false),\n            (\"conv1\", 2, \"Conv\", true),\n            (\"relu1\", 3, \"Relu\", false),\n        ]);\n        let model = onnx_proto::ModelProto::default();\n        let traced = HashMap::new();\n        let points = determine_slice_points(&analysis, None, TEST_OPS, &model, &traced);\n        assert!(points.contains(&0));\n        assert!(points.contains(&2));\n        let max = *points.last().unwrap();\n        assert_eq!(max, 4);\n    }\n\n    #[test]\n    fn determine_slice_points_with_tile_size() {\n        let analysis = make_analysis_with_params(vec![\n            (\"conv0\", 0, \"Conv\", true),\n            (\"relu0\", 1, \"Relu\", false),\n            (\"pool\", 2, \"MaxPool\", false),\n            (\"conv1\", 3, \"Conv\", true),\n        ]);\n        let model = onnx_proto::ModelProto::default();\n        let traced = HashMap::new();\n        let points = determine_slice_points(&analysis, Some(1024), TEST_OPS, &model, &traced);\n        
assert!(points.contains(&0));\n        assert!(points.len() >= 3);\n    }\n\n    type NodeSpec<'a> = (&'a str, usize, &'a str, bool, Vec<&'a str>, Vec<&'a str>);\n\n    fn make_analysis_with_deps(nodes: Vec<NodeSpec<'_>>) -> AnalysisResult {\n        let mut node_map = HashMap::new();\n        for (name, index, op_type, has_params, inputs, outputs) in &nodes {\n            let mut parameter_details = HashMap::new();\n            if *has_params {\n                parameter_details.insert(\n                    format!(\"{}_weight\", name),\n                    analyzer::ParameterDetail {\n                        shape: vec![3, 3],\n                        size: 9,\n                    },\n                );\n            }\n            node_map.insert(\n                name.to_string(),\n                NodeAnalysis {\n                    index: *index,\n                    slice_name: format!(\"{}_{}\", op_type, index),\n                    node_type: op_type.to_string(),\n                    parameter_details,\n                    dependencies: NodeDependencies {\n                        input: inputs.iter().map(|s| s.to_string()).collect(),\n                        output: outputs.iter().map(|s| s.to_string()).collect(),\n                    },\n                },\n            );\n        }\n        AnalysisResult {\n            original_model: None,\n            model_type: \"ONNX\".to_string(),\n            node_count: nodes.len(),\n            initializer_count: 0,\n            input_shape: vec![],\n            output_shapes: vec![],\n            output_names: vec![],\n            opset_version: Some(18),\n            nodes: node_map,\n            initializer_names: HashSet::new(),\n        }\n    }\n\n    #[test]\n    fn merge_control_flow_removes_boundary_between_producer_and_loop() {\n        let analysis = make_analysis_with_deps(vec![\n            (\"conv0\", 0, \"Conv\", true, vec![\"x\"], vec![\"conv_out\"]),\n            (\n                \"relu0\",\n   
             1,\n                \"Relu\",\n                false,\n                vec![\"conv_out\"],\n                vec![\"relu_out\"],\n            ),\n            (\n                \"matmul0\",\n                2,\n                \"MatMul\",\n                true,\n                vec![\"relu_out\"],\n                vec![\"mm_out\"],\n            ),\n            (\n                \"loop0\",\n                3,\n                \"Loop\",\n                false,\n                vec![\"trip\", \"cond\", \"init\", \"relu_out\"],\n                vec![\"loop_out\"],\n            ),\n        ]);\n        let points = vec![0, 2, 4];\n        let result = merge_control_flow_segments(&points, &analysis);\n        assert!(\n            !result.contains(&2),\n            \"slice point 2 separates relu0 (producer of relu_out at idx 1) from Loop (idx 3); must be removed: {:?}\",\n            result\n        );\n    }\n\n    #[test]\n    fn merge_control_flow_preserves_unrelated_boundaries() {\n        let analysis = make_analysis_with_deps(vec![\n            (\"conv0\", 0, \"Conv\", true, vec![\"x\"], vec![\"conv_out\"]),\n            (\n                \"relu0\",\n                1,\n                \"Relu\",\n                false,\n                vec![\"conv_out\"],\n                vec![\"relu_out\"],\n            ),\n            (\n                \"conv1\",\n                2,\n                \"Conv\",\n                true,\n                vec![\"relu_out\"],\n                vec![\"conv1_out\"],\n            ),\n            (\n                \"relu1\",\n                3,\n                \"Relu\",\n                false,\n                vec![\"conv1_out\"],\n                vec![\"relu1_out\"],\n            ),\n            (\n                \"loop0\",\n                4,\n                \"Loop\",\n                false,\n                vec![\"trip\", \"cond\", \"relu1_out\"],\n                vec![\"loop_out\"],\n            ),\n        ]);\n        
let points = vec![0, 2, 5];\n        let result = merge_control_flow_segments(&points, &analysis);\n        assert!(\n            result.contains(&2),\n            \"boundary at 2 is between conv0/relu0 and conv1/relu1, should be preserved since Loop only depends on relu1_out (idx 3): {:?}\",\n            result\n        );\n    }\n\n    #[test]\n    fn merge_control_flow_no_control_flow_ops() {\n        let analysis = make_analysis_with_deps(vec![\n            (\"conv0\", 0, \"Conv\", true, vec![\"x\"], vec![\"conv_out\"]),\n            (\"relu0\", 1, \"Relu\", false, vec![\"conv_out\"], vec![\"y\"]),\n        ]);\n        let points = vec![0, 1, 2];\n        let result = merge_control_flow_segments(&points, &analysis);\n        assert_eq!(result, vec![0, 1, 2]);\n    }\n\n    /// Regression for PR #183: isolate_conv's inner grouping walk\n    /// must treat the LAYOUT_OPS set (Reshape / Transpose /\n    /// Flatten / Squeeze / Unsqueeze / Gather) as passthroughs so\n    /// that Conv -> Reshape -> MatMul places the trailing compile\n    /// boundary on the heavy MatMul rather than on the Reshape\n    /// that sits between them.  
Before the is_slice_passthrough\n    /// split these ops were absent from is_shape_preserving and\n    /// the walk terminated on the Reshape, isolating it into its\n    /// own slice.\n    #[test]\n    fn isolate_conv_absorbs_reshape_then_boundaries_on_matmul() {\n        let analysis = make_analysis_with_deps(vec![\n            (\"conv0\", 0, \"Conv\", true, vec![\"x\"], vec![\"conv_out\"]),\n            (\n                \"reshape0\",\n                1,\n                \"Reshape\",\n                false,\n                vec![\"conv_out\", \"shape\"],\n                vec![\"reshape_out\"],\n            ),\n            (\n                \"matmul0\",\n                2,\n                \"MatMul\",\n                true,\n                vec![\"reshape_out\", \"matmul0_weight\"],\n                vec![\"matmul_out\"],\n            ),\n        ]);\n        let points = vec![0, 3];\n        let result = isolate_conv(&points, &analysis);\n        assert!(\n            result.contains(&0),\n            \"isolate_conv should insert a boundary at the Conv itself: {result:?}\"\n        );\n        assert!(\n            result.contains(&2),\n            \"is_slice_passthrough should absorb Reshape into the Conv slice and place the trailing boundary on MatMul at index 2: {result:?}\"\n        );\n        assert!(\n            !result.contains(&1),\n            \"Reshape at index 1 must not become its own slice boundary when it sits between a Conv and a heavy op: {result:?}\"\n        );\n    }\n\n    /// Transpose + Squeeze variant so we also cover the other\n    /// layout ops added to LAYOUT_OPS.\n    #[test]\n    fn isolate_conv_absorbs_transpose_chain_then_boundaries_on_matmul() {\n        let analysis = make_analysis_with_deps(vec![\n            (\"conv0\", 0, \"Conv\", true, vec![\"x\"], vec![\"conv_out\"]),\n            (\n                \"transpose0\",\n                1,\n                \"Transpose\",\n                false,\n                
vec![\"conv_out\"],\n                vec![\"trans_out\"],\n            ),\n            (\n                \"squeeze0\",\n                2,\n                \"Squeeze\",\n                false,\n                vec![\"trans_out\"],\n                vec![\"sq_out\"],\n            ),\n            (\n                \"matmul0\",\n                3,\n                \"MatMul\",\n                true,\n                vec![\"sq_out\", \"matmul0_weight\"],\n                vec![\"matmul_out\"],\n            ),\n        ]);\n        let points = vec![0, 4];\n        let result = isolate_conv(&points, &analysis);\n        assert!(result.contains(&0));\n        assert!(\n            result.contains(&3),\n            \"Transpose + Squeeze chain should absorb into the Conv slice, leaving MatMul at index 3 as the boundary: {result:?}\"\n        );\n        assert!(!result.contains(&1));\n        assert!(!result.contains(&2));\n    }\n\n    /// Counter-case: a layout op whose input is NOT produced by\n    /// the preceding Conv slice must still break the walk, so the\n    /// consumes_produced guard is exercised.\n    #[test]\n    fn isolate_conv_stops_when_passthrough_consumes_external_input() {\n        let analysis = make_analysis_with_deps(vec![\n            (\"conv0\", 0, \"Conv\", true, vec![\"x\"], vec![\"conv_out\"]),\n            (\n                \"reshape0\",\n                1,\n                \"Reshape\",\n                false,\n                // Reshape consumes an external tensor, not\n                // conv_out, so is_slice_passthrough being true is\n                // not sufficient to absorb it.\n                vec![\"external_y\", \"shape\"],\n                vec![\"reshape_out\"],\n            ),\n        ]);\n        let points = vec![0, 2];\n        let result = isolate_conv(&points, &analysis);\n        assert!(result.contains(&0));\n        assert!(\n            result.contains(&1),\n            \"Reshape that doesn't consume any conv-produced 
tensor should remain the trailing boundary: {result:?}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/self_div_rewrite.rs",
    "content": "use std::collections::HashMap;\n\nuse super::onnx_proto::{ModelProto, TensorProto};\n\n/// Graph rewrite placeholder: detecting `Div(X, X)` and collapsing it to a\n/// constant-ones tensor is only sound when the element type is a floating\n/// point dtype AND every element of X is finite AND non-zero.  Without a\n/// traced-properties side channel carrying that guarantee the rewrite would\n/// silently turn `0 / 0 = NaN` and integer underflow into `1`.\n///\n/// The earlier implementation rewrote unconditionally and is preserved here\n/// as documentation so that a follow-up can plug it in once\n/// `traced_dtypes` / `traced_all_finite_nonzero` maps are available.\npub fn rewrite_self_div_to_one(\n    _model: &mut ModelProto,\n    _traced_shapes: &mut HashMap<String, Vec<i64>>,\n) -> usize {\n    // Intentionally a no-op: see module doc.  Re-enable behind a proper\n    // traced-properties guard once available.\n    let _ = TensorProto::FLOAT;\n    0\n}\n"
  },
  {
    "path": "crates/dsperse/src/slicer/trace.rs",
    "content": "use std::collections::{HashMap, HashSet};\nuse std::path::Path;\n\nuse super::onnx_proto::ModelProto;\nuse crate::error::{DsperseError, Result};\n\npub(crate) struct TraceResult {\n    pub shapes: HashMap<String, Vec<i64>>,\n    pub types: HashMap<String, i32>,\n}\n\npub(crate) fn fold_and_trace_via_tract(\n    onnx_path: &Path,\n    model: &ModelProto,\n) -> Result<TraceResult> {\n    use tract_onnx::prelude::*;\n    use tract_onnx::tract_hir::infer::InferenceSimplePlan;\n\n    let loop_bodies = collect_loop_bodies(model);\n\n    let tract_path = tag_all_outputs(onnx_path, model)?;\n    let tract_model = std::sync::Arc::new(\n        tract_onnx::onnx()\n            .model_for_path(&tract_path)\n            .map_err(|e| DsperseError::Slicer(format!(\"tract load: {e}\")))?,\n    );\n    if let Err(e) = std::fs::remove_file(&tract_path) {\n        tracing::debug!(path = %tract_path.display(), error = %e, \"failed to remove tagged model\");\n    }\n\n    let plan = InferenceSimplePlan::new(tract_model.clone())\n        .map_err(|e| DsperseError::Slicer(format!(\"plan creation: {e}\")))?;\n\n    let mut state = tract_onnx::tract_core::plan::SimpleState::new(&plan)\n        .map_err(|e| DsperseError::Slicer(format!(\"state creation: {e}\")))?;\n\n    let mut input_tvs: TVec<TValue> = tvec![];\n    for outlet in tract_model\n        .input_outlets()\n        .map_err(|e| DsperseError::Slicer(format!(\"input outlets: {e}\")))?\n    {\n        let fact = tract_model\n            .outlet_fact(*outlet)\n            .map_err(|e| DsperseError::Slicer(format!(\"input fact: {e}\")))?;\n        let tensor = if let Ok(tf) = fact.to_typed_fact() {\n            let shape: Vec<usize> = tf\n                .shape\n                .iter()\n                .map(|d| d.to_i64().unwrap_or(1).max(1) as usize)\n                .collect();\n            Tensor::zero_dt(tf.datum_type, &shape)\n                .map_err(|e| DsperseError::Slicer(format!(\"zero tensor: {e}\")))?\n   
     } else {\n            Tensor::zero::<f32>(&[1]).expect(\"scalar f32 allocation\")\n        };\n        input_tvs.push(tensor.into_tvalue());\n    }\n\n    let shapes_cell = std::cell::RefCell::new(HashMap::<usize, Vec<Vec<i64>>>::new());\n    let dtypes_cell = std::cell::RefCell::new(HashMap::<usize, Vec<u8>>::new());\n    let failed_nodes = std::cell::RefCell::new(HashSet::<usize>::new());\n\n    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {\n        state.run_plan_with_eval(input_tvs, |session, op_state, node, inputs| {\n            let tainted = node\n                .inputs\n                .iter()\n                .any(|inp| failed_nodes.borrow().contains(&inp.node));\n            let outputs = if tainted {\n                failed_nodes.borrow_mut().insert(node.id);\n                let fallback = inputs.first().cloned().unwrap_or_else(|| {\n                    Tensor::zero::<f32>(&[1])\n                        .expect(\"scalar f32 allocation\")\n                        .into_tvalue()\n                });\n                let n = node.outputs.len().max(1);\n                (0..n).map(|_| fallback.clone()).collect()\n            } else {\n                let coerced = crate::backend::onnx::coerce_tdim_inputs(&inputs);\n                let eval_result = if let Some(st) = op_state {\n                    st.eval(session, node.op.as_op(), coerced)\n                } else {\n                    node.op.eval(coerced)\n                };\n                match eval_result {\n                    Ok(o) => o,\n                    Err(e) => {\n                        if let Some(synth) =\n                            synthesize_loop_outputs(&node.name, &inputs, &loop_bodies)\n                        {\n                            tracing::info!(\n                                node = %node.name,\n                                outputs = synth.len(),\n                                \"synthesized Loop output tensors from body subgraph 
shapes\"\n                            );\n                            synth\n                        } else {\n                            tracing::warn!(\n                                node = %node.name,\n                                op = %node.op.name(),\n                                error = %e,\n                                \"op eval failed, using input[0] shape as fallback\"\n                            );\n                            failed_nodes.borrow_mut().insert(node.id);\n                            let fallback = inputs.first().cloned().unwrap_or_else(|| {\n                                Tensor::zero::<f32>(&[1])\n                                    .expect(\"scalar f32 allocation\")\n                                    .into_tvalue()\n                            });\n                            let n = node.outputs.len().max(1);\n                            (0..n).map(|_| fallback.clone()).collect()\n                        }\n                    }\n                }\n            };\n            let node_shapes: Vec<Vec<i64>> = outputs\n                .iter()\n                .map(|t| t.shape().iter().map(|&d| d as i64).collect())\n                .collect();\n            let node_dtypes: Vec<u8> = outputs\n                .iter()\n                .map(|t| datum_type_to_onnx(t.datum_type()))\n                .collect();\n            shapes_cell.borrow_mut().insert(node.id, node_shapes);\n            dtypes_cell.borrow_mut().insert(node.id, node_dtypes);\n            Ok::<_, TractError>(outputs)\n        })\n    }));\n\n    match &result {\n        Ok(Ok(_)) => tracing::info!(\"tract inference run succeeded\"),\n        Ok(Err(e)) => {\n            tracing::warn!(error = %e, \"tract inference run produced errors; partial shapes may be available\")\n        }\n        Err(_) => {\n            return Err(DsperseError::Slicer(\n                \"tract inference panicked; no shape data recovered\".into(),\n            ));\n        }\n    }\n\n   
 let run_shapes = shapes_cell.into_inner();\n    let run_dtypes = dtypes_cell.into_inner();\n    let failed = failed_nodes.into_inner();\n\n    tracing::info!(\n        traced_nodes = run_shapes.len(),\n        \"constant folding and shape capture complete\"\n    );\n\n    let mut shapes: HashMap<String, Vec<i64>> = HashMap::new();\n    let mut types: HashMap<String, i32> = HashMap::new();\n    for (node_id, node_shapes) in &run_shapes {\n        if failed.contains(node_id) {\n            continue;\n        }\n        let node_dtypes = run_dtypes.get(node_id);\n        for (slot, shape) in node_shapes.iter().enumerate() {\n            let dt = node_dtypes.and_then(|d| d.get(slot)).copied().unwrap_or(1) as i32; // 1 = FLOAT\n            let outlet = OutletId::new(*node_id, slot);\n            if let Some(label) = tract_model.outlet_label(outlet)\n                && !label.is_empty()\n            {\n                shapes.insert(label.to_string(), shape.clone());\n                types.insert(label.to_string(), dt);\n            }\n            let node = tract_model.node(*node_id);\n            if !node.name.is_empty() {\n                if slot == 0 {\n                    shapes\n                        .entry(node.name.clone())\n                        .or_insert_with(|| shape.clone());\n                    types.entry(node.name.clone()).or_insert(dt);\n                }\n                let qualified = format!(\"{}:{}\", node.name, slot);\n                shapes\n                    .entry(qualified.clone())\n                    .or_insert_with(|| shape.clone());\n                types.entry(qualified).or_insert(dt);\n            }\n        }\n    }\n\n    if let Some(graph) = &model.graph {\n        let mut extra: Vec<(String, Vec<i64>, Option<i32>)> = Vec::new();\n        for n in &graph.node {\n            for (slot, out) in n.output.iter().enumerate() {\n                if out.is_empty() || shapes.contains_key(out) {\n                    continue;\n            
    }\n                let key = if slot == 0 {\n                    n.name.clone()\n                } else {\n                    format!(\"{}:{}\", n.name, slot)\n                };\n                if let Some(shape) = shapes.get(&key) {\n                    let dt = types.get(&key).copied();\n                    extra.push((out.clone(), shape.clone(), dt));\n                }\n            }\n        }\n        for (name, shape, dt) in extra {\n            shapes.insert(name.clone(), shape);\n            if let Some(dt) = dt {\n                types.insert(name, dt);\n            }\n        }\n        for init in &graph.initializer {\n            if !init.dims.is_empty() {\n                shapes\n                    .entry(init.name.clone())\n                    .or_insert_with(|| init.dims.clone());\n            }\n            if init.data_type != 0 {\n                types.entry(init.name.clone()).or_insert(init.data_type);\n            }\n        }\n        for inp in &graph.input {\n            if let Some(shape) = super::onnx_shapes::shape_from_value_info(inp) {\n                shapes.entry(inp.name.clone()).or_insert(shape);\n            }\n            if let Some(dt) = super::onnx_shapes::elem_type_from_value_info(inp) {\n                types.entry(inp.name.clone()).or_insert(dt);\n            }\n        }\n        resolve_absorbed_nodes(graph, &mut shapes);\n    }\n\n    tracing::info!(tensors = shapes.len(), \"shape trace complete\");\n    Ok(TraceResult { shapes, types })\n}\n\n/// Save a copy of the ONNX model with every node output declared as a graph\n/// output.  
This forces tract to preserve outlet labels for all intermediate\n/// tensors, preventing them from being lost during op fusion.\nfn tag_all_outputs(onnx_path: &Path, model: &ModelProto) -> Result<std::path::PathBuf> {\n    let mut tagged = model.clone();\n    if let Some(ref mut graph) = tagged.graph {\n        let existing: HashSet<String> = graph.output.iter().map(|o| o.name.clone()).collect();\n        for node in &graph.node {\n            for out in &node.output {\n                if !out.is_empty() && !existing.contains(out) {\n                    graph.output.push(super::onnx_proto::ValueInfoProto {\n                        name: out.clone(),\n                        ..Default::default()\n                    });\n                }\n            }\n        }\n    }\n    let dir = onnx_path.parent().unwrap_or_else(|| Path::new(\".\"));\n    let tagged_path = dir.join(format!(\"_tract_tagged_{}.onnx\", std::process::id()));\n    super::onnx_proto::save_model(&tagged, &tagged_path)?;\n    Ok(tagged_path)\n}\n\nfn onnx_elem_type_to_datum(onnx_type: i32) -> Option<tract_onnx::prelude::DatumType> {\n    use tract_onnx::prelude::DatumType;\n    match onnx_type {\n        1 => Some(DatumType::F32),\n        2 => Some(DatumType::U8),\n        3 => Some(DatumType::I8),\n        5 => Some(DatumType::I16),\n        6 => Some(DatumType::I32),\n        7 => Some(DatumType::I64),\n        9 => Some(DatumType::Bool),\n        10 => Some(DatumType::F16),\n        11 => Some(DatumType::F64),\n        12 => Some(DatumType::U32),\n        13 => Some(DatumType::U64),\n        _ => None,\n    }\n}\n\nfn datum_type_to_onnx(dt: tract_onnx::prelude::DatumType) -> u8 {\n    use tract_onnx::prelude::DatumType;\n    match dt {\n        DatumType::F32 => 1,\n        DatumType::U8 => 2,\n        DatumType::I8 => 3,\n        DatumType::U16 => 4,\n        DatumType::I16 => 5,\n        DatumType::I32 => 6,\n        DatumType::I64 => 7,\n        DatumType::Bool => 9,\n        DatumType::F16 
=> 10,\n        DatumType::F64 => 11,\n        DatumType::U32 => 12,\n        DatumType::U64 => 13,\n        _ => 1,\n    }\n}\n\nstruct LoopBody {\n    num_loop_carried: usize,\n    num_scan: usize,\n    scan_body_output_shapes: Vec<Option<Vec<i64>>>,\n    scan_body_output_dtypes: Vec<Option<i32>>,\n}\n\n/// Collect Loop node body metadata from the ONNX graph.  For scan outputs\n/// whose shapes can be statically determined from the body subgraph, store\n/// the body-side shape (without the leading trip-count dimension).\nfn collect_loop_bodies(model: &ModelProto) -> HashMap<String, LoopBody> {\n    let graph = match model.graph.as_ref() {\n        Some(g) => g,\n        None => return HashMap::new(),\n    };\n\n    let mut known: HashMap<String, Vec<i64>> = HashMap::new();\n    for init in &graph.initializer {\n        if !init.dims.is_empty() {\n            known.insert(init.name.clone(), init.dims.clone());\n        }\n    }\n    for vi in graph\n        .input\n        .iter()\n        .chain(graph.value_info.iter())\n        .chain(graph.output.iter())\n    {\n        if let Some(shape) = super::onnx_shapes::shape_from_value_info(vi) {\n            known.insert(vi.name.clone(), shape);\n        }\n    }\n\n    let mut result = HashMap::new();\n    for node in &graph.node {\n        if node.op_type != \"Loop\" {\n            continue;\n        }\n        let body = match node\n            .attribute\n            .iter()\n            .find(|a| a.name == \"body\")\n            .and_then(|a| a.g.as_ref())\n        {\n            Some(b) => b,\n            None => continue,\n        };\n\n        let num_loop_carried = node.input.len().saturating_sub(2);\n        let num_body_out = body.output.len().saturating_sub(1);\n        let num_scan = num_body_out.saturating_sub(num_loop_carried);\n\n        let mut scan_shapes = Vec::with_capacity(num_scan);\n        let mut scan_dtypes = Vec::with_capacity(num_scan);\n        for j in 0..num_scan {\n            let 
body_out_idx = 1 + num_loop_carried + j;\n            let body_vi = body.output.get(body_out_idx);\n            let shape =\n                body_vi.and_then(|vi| resolve_body_tensor_shape(&vi.name, body, graph, &known));\n            let dtype = body_vi.and_then(super::onnx_shapes::elem_type_from_value_info);\n            scan_shapes.push(shape);\n            scan_dtypes.push(dtype);\n        }\n\n        result.insert(\n            node.name.clone(),\n            LoopBody {\n                num_loop_carried,\n                num_scan,\n                scan_body_output_shapes: scan_shapes,\n                scan_body_output_dtypes: scan_dtypes,\n            },\n        );\n    }\n    result\n}\n\n/// During tract evaluation, when a Loop node fails, produce correctly-shaped\n/// zero tensors so downstream nodes receive valid inputs and are not tainted.\n///\n/// Loop-carried output shapes come directly from the actual input tensors\n/// (inputs\\[2..\\]).  Scan output shapes come from the pre-analyzed body\n/// subgraph with a leading dimension of 1 (single iteration assumption).\nfn synthesize_loop_outputs(\n    node_name: &str,\n    inputs: &[tract_onnx::prelude::TValue],\n    loop_bodies: &HashMap<String, LoopBody>,\n) -> Option<tract_onnx::prelude::TVec<tract_onnx::prelude::TValue>> {\n    use tract_onnx::prelude::*;\n\n    let body = loop_bodies.get(node_name)?;\n    let mut tvs: TVec<TValue> = tvec![];\n\n    for i in 0..body.num_loop_carried {\n        let init_tensor = inputs.get(i + 2)?;\n        let shape: Vec<usize> = init_tensor.shape().to_vec();\n        let tensor = Tensor::zero_dt(init_tensor.datum_type(), &shape).ok()?;\n        tvs.push(tensor.into_tvalue());\n    }\n\n    for j in 0..body.num_scan {\n        let body_shape = body.scan_body_output_shapes.get(j)?;\n        let shape: Vec<usize> = match body_shape {\n            Some(bs) => {\n                let mut s = vec![1usize];\n                s.extend(bs.iter().map(|&d| d.max(1) as usize));\n 
               s\n            }\n            None => {\n                tracing::warn!(\n                    node = node_name,\n                    scan_idx = j,\n                    \"scan output shape unknown, using [1,1] placeholder\"\n                );\n                vec![1, 1]\n            }\n        };\n        let dt = body\n            .scan_body_output_dtypes\n            .get(j)\n            .and_then(|d| *d)\n            .and_then(onnx_elem_type_to_datum)\n            .unwrap_or(DatumType::F32);\n        let tensor = Tensor::zero_dt(dt, &shape).ok()?;\n        tvs.push(tensor.into_tvalue());\n    }\n\n    Some(tvs)\n}\n\n/// Resolve shapes for ONNX graph nodes that tract absorbed or renamed,\n/// making them invisible in the tract shape output.  Iterates until no\n/// more progress, using only rules already defined in the slicer module\n/// (shape-preserving ops, binary broadcast).\nfn resolve_absorbed_nodes(\n    graph: &super::onnx_proto::GraphProto,\n    shapes: &mut HashMap<String, Vec<i64>>,\n) {\n    let max_passes = 10;\n    for _ in 0..max_passes {\n        let mut progress = false;\n        for node in &graph.node {\n            for out in &node.output {\n                if out.is_empty() || shapes.contains_key(out) {\n                    continue;\n                }\n                let op = node.op_type.as_str();\n                let shape = if super::is_shape_preserving(op) || op == \"Identity\" {\n                    node.input.first().and_then(|inp| shapes.get(inp).cloned())\n                } else if super::is_binary_arithmetic(op) {\n                    let resolved: Vec<&Vec<i64>> =\n                        node.input.iter().filter_map(|i| shapes.get(i)).collect();\n                    let non_empty = node.input.iter().filter(|i| !i.is_empty()).count();\n                    if resolved.len() == non_empty {\n                        super::onnx_slicer::broadcast_shapes(&resolved)\n                    } else {\n                        
None\n                    }\n                } else {\n                    None\n                };\n                if let Some(s) = shape {\n                    shapes.insert(out.clone(), s);\n                    progress = true;\n                }\n            }\n        }\n        if !progress {\n            break;\n        }\n    }\n}\n\nfn resolve_body_tensor_shape(\n    name: &str,\n    body: &super::onnx_proto::GraphProto,\n    outer_graph: &super::onnx_proto::GraphProto,\n    known_shapes: &HashMap<String, Vec<i64>>,\n) -> Option<Vec<i64>> {\n    resolve_body_tensor_shape_inner(name, body, outer_graph, known_shapes, 0)\n}\n\nfn resolve_body_tensor_shape_inner(\n    name: &str,\n    body: &super::onnx_proto::GraphProto,\n    outer_graph: &super::onnx_proto::GraphProto,\n    known_shapes: &HashMap<String, Vec<i64>>,\n    depth: usize,\n) -> Option<Vec<i64>> {\n    if depth > 32 {\n        return None;\n    }\n\n    for vi in body.output.iter().chain(body.value_info.iter()) {\n        if vi.name == name\n            && let Some(shape) = super::onnx_shapes::shape_from_value_info(vi)\n        {\n            return Some(shape);\n        }\n    }\n\n    for init in body\n        .initializer\n        .iter()\n        .chain(outer_graph.initializer.iter())\n    {\n        if init.name == name && !init.dims.is_empty() {\n            return Some(init.dims.to_vec());\n        }\n    }\n\n    if let Some(shape) = known_shapes.get(name) {\n        return Some(shape.clone());\n    }\n\n    let producer = body\n        .node\n        .iter()\n        .find(|n| n.output.contains(&name.to_string()))?;\n    let op = producer.op_type.as_str();\n\n    if super::is_shape_preserving(op) || op == \"Identity\" {\n        let inp = producer.input.first()?;\n        return resolve_body_tensor_shape_inner(inp, body, outer_graph, known_shapes, depth + 1);\n    }\n\n    if super::is_binary_arithmetic(op) {\n        let resolved: Vec<Vec<i64>> = producer\n            .input\n           
 .iter()\n            .filter_map(|inp| {\n                resolve_body_tensor_shape_inner(inp, body, outer_graph, known_shapes, depth + 1)\n            })\n            .collect();\n        let refs: Vec<&Vec<i64>> = resolved.iter().collect();\n        return super::onnx_slicer::broadcast_shapes(&refs);\n    }\n\n    if op == \"Concat\" {\n        let axis = super::onnx_proto::get_attribute_int(producer, \"axis\")?;\n        let input_shapes: Vec<Vec<i64>> = producer\n            .input\n            .iter()\n            .filter_map(|inp| {\n                resolve_body_tensor_shape_inner(inp, body, outer_graph, known_shapes, depth + 1)\n            })\n            .collect();\n        if input_shapes.len() != producer.input.len() || input_shapes.is_empty() {\n            return None;\n        }\n        let rank = input_shapes[0].len() as i64;\n        if axis < -rank || axis >= rank {\n            return None;\n        }\n        let axis_idx = if axis < 0 {\n            (rank + axis) as usize\n        } else {\n            axis as usize\n        };\n        let mut result = input_shapes[0].clone();\n        for shape in &input_shapes[1..] {\n            if let Some(d) = result.get_mut(axis_idx) {\n                *d += shape.get(axis_idx).copied().unwrap_or(0);\n            }\n        }\n        return Some(result);\n    }\n\n    if op == \"Transpose\" {\n        let inp = producer.input.first()?;\n        let in_shape =\n            resolve_body_tensor_shape_inner(inp, body, outer_graph, known_shapes, depth + 1)?;\n        let perm = &producer.attribute.iter().find(|a| a.name == \"perm\")?.ints;\n        let result: Vec<i64> = perm\n            .iter()\n            .filter_map(|&p| in_shape.get(p as usize).copied())\n            .collect();\n        if result.len() == in_shape.len() {\n            return Some(result);\n        }\n    }\n\n    None\n}\n"
  },
  {
    "path": "crates/dsperse/src/utils/io.rs",
    "content": "use std::collections::HashMap;\nuse std::path::Path;\n\nuse ndarray::{ArrayD, Axis, IxDyn};\nuse rmpv::Value;\n\nuse crate::error::{DsperseError, Result};\n\npub fn read_msgpack(path: &Path) -> Result<Value> {\n    let data = crate::utils::limits::read_checked(path)?;\n    rmp_serde::from_slice(&data).map_err(Into::into)\n}\n\npub fn write_msgpack(path: &Path, value: &Value) -> Result<()> {\n    if let Some(parent) = path.parent() {\n        std::fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?;\n    }\n    let data = rmp_serde::to_vec_named(value)?;\n    std::fs::write(path, data).map_err(|e| DsperseError::io(e, path))\n}\n\npub fn extract_input_data(value: &Value) -> Option<&Value> {\n    map_get_ref(value, \"input_data\")\n        .or_else(|| map_get_ref(value, \"input\"))\n        .or_else(|| map_get_ref(value, \"data\"))\n        .or_else(|| map_get_ref(value, \"inputs\"))\n}\n\npub fn flatten_nested_list(value: &Value) -> Vec<f64> {\n    let mut result = Vec::new();\n    flatten_recursive(value, &mut result);\n    result\n}\n\nfn flatten_recursive(value: &Value, out: &mut Vec<f64>) {\n    match value {\n        Value::F64(f) => out.push(*f),\n        Value::F32(f) => out.push(*f as f64),\n        Value::Integer(n) => {\n            if let Some(f) = n.as_f64() {\n                out.push(f);\n            } else {\n                tracing::warn!(number = ?n, \"flatten_recursive: dropping non-f64 representable integer\");\n            }\n        }\n        Value::Array(arr) => {\n            for item in arr {\n                flatten_recursive(item, out);\n            }\n        }\n        other => {\n            tracing::warn!(variant = %other, \"flatten_recursive: dropping non-numeric value during flattening\");\n        }\n    }\n}\n\npub fn infer_shape(value: &Value) -> Vec<usize> {\n    let mut shape = Vec::new();\n    let mut current = value;\n    while let Value::Array(arr) = current {\n        shape.push(arr.len());\n  
      if let Some(first) = arr.first() {\n            current = first;\n        } else {\n            break;\n        }\n    }\n    shape\n}\n\npub fn value_to_arrayd(value: &Value) -> Result<ArrayD<f64>> {\n    let flat = flatten_nested_list(value);\n    let shape = infer_shape(value);\n    if flat.is_empty() {\n        return ArrayD::from_shape_vec(IxDyn(&shape), vec![])\n            .map_err(|e| DsperseError::Pipeline(format!(\"empty arrayd: {e}\")));\n    }\n    if shape.is_empty() && flat.len() == 1 {\n        return ArrayD::from_shape_vec(IxDyn(&[]), flat)\n            .map_err(|e| DsperseError::Pipeline(format!(\"scalar arrayd: {e}\")));\n    }\n    let product: usize = shape.iter().product();\n    if product != flat.len() || shape.is_empty() {\n        tracing::warn!(\n            flat_len = flat.len(),\n            ?shape,\n            product,\n            \"shape mismatch, falling back to 1D\"\n        );\n        return ArrayD::from_shape_vec(IxDyn(&[flat.len()]), flat)\n            .map_err(|e| DsperseError::Pipeline(format!(\"arrayd reshape fallback: {e}\")));\n    }\n    ArrayD::from_shape_vec(IxDyn(&shape), flat)\n        .map_err(|e| DsperseError::Pipeline(format!(\"arrayd reshape: {e}\")))\n}\n\npub fn arrayd_to_value(arr: &ArrayD<f64>) -> Value {\n    match arr.ndim() {\n        0 => Value::F64(arr[IxDyn(&[])]),\n        1 => {\n            let vals: Vec<Value> = arr.iter().map(|&v| Value::F64(v)).collect();\n            Value::Array(vals)\n        }\n        _ => {\n            let vals: Vec<Value> = (0..arr.shape()[0])\n                .map(|i| {\n                    let sub = arr.index_axis(Axis(0), i).to_owned();\n                    arrayd_to_value(&sub)\n                })\n                .collect();\n            Value::Array(vals)\n        }\n    }\n}\n\npub fn gather_inputs_from_cache(\n    cache: &HashMap<String, ArrayD<f64>>,\n    inputs: &[String],\n) -> Result<ArrayD<f64>> {\n    let mut collected = Vec::new();\n    let mut missing = 
Vec::new();\n    for name in inputs {\n        if let Some(val) = cache.get(name) {\n            collected.push(val.clone());\n        } else {\n            missing.push(name.clone());\n        }\n    }\n    if collected.is_empty() {\n        return Err(DsperseError::Pipeline(format!(\n            \"no cached tensor found for inputs: {inputs:?}\"\n        )));\n    }\n    if !missing.is_empty() {\n        return Err(DsperseError::Pipeline(format!(\n            \"missing tensors in cache: {missing:?} (found {} of {})\",\n            collected.len(),\n            inputs.len()\n        )));\n    }\n    if collected.len() == 1 {\n        return Ok(collected.into_iter().next().unwrap());\n    }\n    if collected[0].ndim() == 0 {\n        return Err(DsperseError::Pipeline(\n            \"cannot concatenate 0-dimensional tensors\".into(),\n        ));\n    }\n    let ref_trailing = collected[0].shape()[1..].to_vec();\n    let ref_product: usize = ref_trailing.iter().product();\n    let batch = collected[0].shape()[0];\n    for (i, arr) in collected.iter_mut().enumerate().skip(1) {\n        let trailing = &arr.shape()[1..];\n        if trailing != ref_trailing.as_slice() {\n            let product: usize = trailing.iter().product();\n            if product == ref_product && arr.shape()[0] == batch {\n                let orig_shape: Vec<usize> = arr.shape().to_vec();\n                let mut target = vec![batch];\n                target.extend_from_slice(&ref_trailing);\n                let owned = std::mem::replace(arr, ArrayD::zeros(ndarray::IxDyn(&[])));\n                *arr = owned\n                    .into_shape_with_order(ndarray::IxDyn(&target))\n                    .map_err(|e| {\n                        DsperseError::Pipeline(format!(\n                            \"gather reshape input {i} from {orig_shape:?} to {target:?}: {e}\",\n                        ))\n                    })?;\n            } else {\n                return 
Err(DsperseError::Pipeline(format!(\n                    \"shape mismatch at input {}: expected trailing dims {:?}, got {:?}\",\n                    i, ref_trailing, trailing\n                )));\n            }\n        }\n    }\n    ndarray::concatenate(\n        ndarray::Axis(0),\n        &collected.iter().map(|a| a.view()).collect::<Vec<_>>(),\n    )\n    .map_err(|e| DsperseError::Pipeline(format!(\"concat inputs: {e}\")))\n}\n\npub fn build_msgpack_map(entries: Vec<(&str, Value)>) -> Value {\n    Value::Map(\n        entries\n            .into_iter()\n            .map(|(k, v)| (Value::String(k.into()), v))\n            .collect(),\n    )\n}\n\npub fn map_get_ref<'a>(value: &'a Value, key: &str) -> Option<&'a Value> {\n    match value {\n        Value::Map(entries) => entries.iter().find_map(|(k, v)| {\n            if k.as_str().is_some_and(|s| s == key) {\n                Some(v)\n            } else {\n                None\n            }\n        }),\n        _ => None,\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/utils/limits.rs",
    "content": "use std::io::Read;\nuse std::path::Path;\n\nuse crate::error::{DsperseError, Result};\n\npub fn reject_symlink(path: &Path) -> Result<()> {\n    let m = std::fs::symlink_metadata(path).map_err(|e| DsperseError::io(e, path))?;\n    if m.is_symlink() {\n        return Err(DsperseError::Archive(format!(\n            \"symlink not permitted: {}\",\n            path.file_name()\n                .and_then(|n| n.to_str())\n                .unwrap_or(\"<unknown>\")\n        )));\n    }\n    Ok(())\n}\n\nfn open_nofollow(path: &Path) -> Result<std::fs::File> {\n    #[cfg(unix)]\n    {\n        use std::os::unix::fs::OpenOptionsExt;\n        std::fs::OpenOptions::new()\n            .read(true)\n            .custom_flags(libc::O_NOFOLLOW)\n            .open(path)\n            .map_err(|e| {\n                if e.raw_os_error() == Some(libc::ELOOP) {\n                    DsperseError::Archive(format!(\n                        \"symlink not permitted: {}\",\n                        path.file_name()\n                            .and_then(|n| n.to_str())\n                            .unwrap_or(\"<unknown>\")\n                    ))\n                } else {\n                    DsperseError::io(e, path)\n                }\n            })\n    }\n    #[cfg(not(unix))]\n    {\n        reject_symlink(path)?;\n        std::fs::File::open(path).map_err(|e| DsperseError::io(e, path))\n    }\n}\n\npub fn read_checked(path: &Path) -> Result<Vec<u8>> {\n    let mut file = open_nofollow(path)?;\n    let mut buf = Vec::new();\n    file.read_to_end(&mut buf)\n        .map_err(|e| DsperseError::io(e, path))?;\n    Ok(buf)\n}\n\npub fn read_to_string_checked(path: &Path) -> Result<String> {\n    let mut file = open_nofollow(path)?;\n    let mut buf = String::new();\n    file.read_to_string(&mut buf)\n        .map_err(|e| DsperseError::io(e, path))?;\n    Ok(buf)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn reject_symlink_on_regular_file() {\n        
let tmp = tempfile::NamedTempFile::new().unwrap();\n        assert!(reject_symlink(tmp.path()).is_ok());\n    }\n\n    #[cfg(unix)]\n    #[test]\n    fn reject_symlink_on_symlink() {\n        let dir = tempfile::tempdir().unwrap();\n        let target = dir.path().join(\"target\");\n        std::fs::write(&target, b\"data\").unwrap();\n        let link = dir.path().join(\"link\");\n        std::os::unix::fs::symlink(&target, &link).unwrap();\n        assert!(reject_symlink(&link).is_err());\n    }\n\n    #[test]\n    fn read_checked_normal() {\n        let tmp = tempfile::NamedTempFile::new().unwrap();\n        std::fs::write(tmp.path(), b\"hello\").unwrap();\n        let data = read_checked(tmp.path()).unwrap();\n        assert_eq!(data, b\"hello\");\n    }\n\n    #[test]\n    fn read_to_string_checked_normal() {\n        let tmp = tempfile::NamedTempFile::new().unwrap();\n        std::fs::write(tmp.path(), \"hello world\").unwrap();\n        let s = read_to_string_checked(tmp.path()).unwrap();\n        assert_eq!(s, \"hello world\");\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/utils/metadata.rs",
    "content": "use std::path::Path;\n\nuse crate::error::{DsperseError, Result};\nuse crate::schema::RunMetadata;\n\npub fn load_run_metadata(path: &Path) -> Result<RunMetadata> {\n    let data = crate::utils::limits::read_checked(path)?;\n    rmp_serde::from_slice(&data).map_err(Into::into)\n}\n\npub fn save_run_metadata(path: &Path, meta: &RunMetadata) -> Result<()> {\n    if let Some(parent) = path.parent() {\n        std::fs::create_dir_all(parent).map_err(|e| DsperseError::io(e, parent))?;\n    }\n    let data = rmp_serde::to_vec_named(meta)?;\n    std::fs::write(path, data).map_err(|e| DsperseError::io(e, path))\n}\n"
  },
  {
    "path": "crates/dsperse/src/utils/mod.rs",
    "content": "pub mod io;\npub mod limits;\npub mod metadata;\npub mod paths;\n"
  },
  {
    "path": "crates/dsperse/src/utils/paths.rs",
    "content": "use std::path::{Component, Path, PathBuf};\n\nuse crate::error::{DsperseError, Result};\n\npub const METADATA_FILE: &str = \"metadata.msgpack\";\npub const INPUT_FILE: &str = \"input.msgpack\";\npub const OUTPUT_FILE: &str = \"output.msgpack\";\npub const WITNESS_FILE: &str = \"witness.bin\";\npub const PROOF_FILE: &str = \"proof.bin\";\n\npub fn resolve_relative_path(base: &Path, relative: &str) -> Result<PathBuf> {\n    let rel = Path::new(relative);\n    if rel.is_absolute() {\n        return Err(DsperseError::Archive(format!(\n            \"absolute path in metadata is not permitted: {relative}\"\n        )));\n    }\n    for component in rel.components() {\n        match component {\n            Component::ParentDir => {\n                return Err(DsperseError::Archive(format!(\n                    \"path traversal component in metadata is not permitted: {relative}\"\n                )));\n            }\n            Component::RootDir | Component::Prefix(_) => {\n                return Err(DsperseError::Archive(format!(\n                    \"invalid path component in metadata: {relative}\"\n                )));\n            }\n            _ => {}\n        }\n    }\n    Ok(base.join(rel))\n}\n\npub fn relativize_path(path: &Path, base: &Path) -> String {\n    path.strip_prefix(base)\n        .map(|p| p.to_string_lossy().to_string())\n        .unwrap_or_else(|_| path.to_string_lossy().to_string())\n}\n\npub fn slice_dir_path(root: &Path, index: usize) -> PathBuf {\n    root.join(format!(\"slice_{index}\"))\n}\n\npub fn find_metadata_path(dir: &Path) -> Option<PathBuf> {\n    let direct = dir.join(METADATA_FILE);\n    if direct.exists() {\n        return Some(direct);\n    }\n    let slices = dir.join(\"slices\").join(METADATA_FILE);\n    if slices.exists() {\n        return Some(slices);\n    }\n    None\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn resolve_relative_normal_path() {\n        let base = 
Path::new(\"/tmp/slices\");\n        let result = resolve_relative_path(base, \"payload/model.onnx\").unwrap();\n        assert_eq!(result, PathBuf::from(\"/tmp/slices/payload/model.onnx\"));\n    }\n\n    #[test]\n    fn resolve_relative_rejects_absolute() {\n        let base = Path::new(\"/tmp/slices\");\n        assert!(resolve_relative_path(base, \"/etc/passwd\").is_err());\n    }\n\n    #[test]\n    fn resolve_relative_rejects_parent_dir() {\n        let base = Path::new(\"/tmp/slices\");\n        assert!(resolve_relative_path(base, \"../../../etc/passwd\").is_err());\n    }\n\n    #[test]\n    fn resolve_relative_rejects_embedded_parent() {\n        let base = Path::new(\"/tmp/slices\");\n        assert!(resolve_relative_path(base, \"payload/../../../etc/passwd\").is_err());\n    }\n\n    #[test]\n    fn resolve_relative_allows_current_dir() {\n        let base = Path::new(\"/tmp/slices\");\n        let result = resolve_relative_path(base, \"./model.onnx\").unwrap();\n        assert_eq!(result, PathBuf::from(\"/tmp/slices/./model.onnx\"));\n    }\n\n    #[test]\n    fn resolve_relative_empty_string() {\n        let base = Path::new(\"/tmp/slices\");\n        let result = resolve_relative_path(base, \"\").unwrap();\n        assert_eq!(result, PathBuf::from(\"/tmp/slices/\"));\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/src/version.rs",
    "content": "use serde::{Deserialize, Serialize};\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct DsperseVersion {\n    pub dsperse_version: String,\n    pub dsperse_rev: Option<String>,\n    pub jstprove_version: String,\n    pub jstprove_rev: Option<String>,\n}\n\npub fn dsperse_artifact_version() -> DsperseVersion {\n    let jst_ver = jstprove_circuits::api::jstprove_artifact_version();\n    DsperseVersion {\n        dsperse_version: env!(\"CARGO_PKG_VERSION\").to_string(),\n        dsperse_rev: option_env!(\"DSPERSE_GIT_REV\").map(String::from),\n        jstprove_version: jst_ver.crate_version,\n        jstprove_rev: Some(jst_ver.git_rev),\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/tests/integration_slice.rs",
    "content": "use std::path::Path;\n\nuse dsperse::schema::metadata::ModelMetadata;\n\nfn test_models_dir() -> &'static Path {\n    Path::new(concat!(env!(\"CARGO_MANIFEST_DIR\"), \"/../../tests/models\"))\n}\n\n#[test]\nfn slice_net_model() {\n    let model_path = test_models_dir().join(\"net/model.onnx\");\n    assert!(\n        model_path.exists(),\n        \"test model not found at {}\",\n        model_path.display()\n    );\n\n    let tmp = tempfile::tempdir().expect(\"create temp dir\");\n    let output_dir = tmp.path().join(\"slices\");\n\n    let metadata = dsperse::slicer::slice_model(\n        &model_path,\n        Some(&output_dir),\n        None,\n        jstprove_circuits::ProofSystem::Expander.supported_ops(),\n        None,\n    )\n    .expect(\"slice_model\");\n\n    assert!(!metadata.slices.is_empty());\n    assert_eq!(metadata.model_type, \"ONNX\");\n    assert!(!metadata.input_shape.is_empty());\n    assert!(!metadata.output_shapes.is_empty());\n\n    let meta_path = output_dir.join(\"metadata.msgpack\");\n    assert!(meta_path.exists(), \"metadata.msgpack must be written\");\n\n    let model_onnx = output_dir.join(\"model.onnx\");\n    assert!(\n        model_onnx.exists(),\n        \"model.onnx must be copied to output dir\"\n    );\n\n    let loaded = ModelMetadata::load(&meta_path).expect(\"load metadata\");\n    assert_eq!(loaded.slices.len(), metadata.slices.len());\n\n    assert!(loaded.traced_shapes.is_some());\n    assert!(loaded.original_model_path.is_some());\n    assert_eq!(loaded.original_model_path.as_deref(), Some(\"model.onnx\"));\n}\n\n#[test]\nfn slice_doom_model() {\n    let model_path = test_models_dir().join(\"doom/model.onnx\");\n    assert!(\n        model_path.exists(),\n        \"test model not found at {}\",\n        model_path.display()\n    );\n\n    let tmp = tempfile::tempdir().expect(\"create temp dir\");\n    let output_dir = tmp.path().join(\"slices\");\n\n    let metadata = dsperse::slicer::slice_model(\n       
 &model_path,\n        Some(&output_dir),\n        None,\n        jstprove_circuits::ProofSystem::Expander.supported_ops(),\n        None,\n    )\n    .expect(\"slice_model\");\n\n    assert!(!metadata.slices.is_empty());\n\n    for (i, slice) in metadata.slices.iter().enumerate() {\n        assert_eq!(slice.index, i);\n        assert!(!slice.dependencies.input.is_empty());\n        assert!(!slice.dependencies.output.is_empty());\n    }\n}\n\n#[test]\nfn slice_net_model_remainder() {\n    let model_path = test_models_dir().join(\"net/model.onnx\");\n    assert!(\n        model_path.exists(),\n        \"test model not found at {}\",\n        model_path.display()\n    );\n\n    let tmp = tempfile::tempdir().expect(\"create temp dir\");\n    let output_dir = tmp.path().join(\"slices\");\n\n    let metadata = dsperse::slicer::slice_model(\n        &model_path,\n        Some(&output_dir),\n        None,\n        jstprove_circuits::ProofSystem::Remainder.supported_ops(),\n        None,\n    )\n    .expect(\"slice_model with Remainder\");\n\n    assert!(!metadata.slices.is_empty());\n    assert_eq!(metadata.model_type, \"ONNX\");\n}\n\n#[test]\nfn slice_with_tile_size() {\n    let model_path = test_models_dir().join(\"net/model.onnx\");\n    assert!(\n        model_path.exists(),\n        \"test model not found at {}\",\n        model_path.display()\n    );\n\n    let tmp = tempfile::tempdir().expect(\"create temp dir\");\n    let output_dir = tmp.path().join(\"slices\");\n\n    let metadata = dsperse::slicer::slice_model(\n        &model_path,\n        Some(&output_dir),\n        Some(8),\n        jstprove_circuits::ProofSystem::Expander.supported_ops(),\n        None,\n    )\n    .expect(\"slice_model\");\n\n    assert!(!metadata.slices.is_empty());\n\n    let meta_path = output_dir.join(\"metadata.msgpack\");\n    assert!(meta_path.exists());\n}\n\n#[test]\nfn slice_metadata_roundtrip_from_disk() {\n    let model_path = test_models_dir().join(\"net/model.onnx\");\n    
assert!(\n        model_path.exists(),\n        \"test model not found at {}\",\n        model_path.display()\n    );\n\n    let tmp = tempfile::tempdir().expect(\"create temp dir\");\n    let output_dir = tmp.path().join(\"slices\");\n\n    let original = dsperse::slicer::slice_model(\n        &model_path,\n        Some(&output_dir),\n        None,\n        jstprove_circuits::ProofSystem::Expander.supported_ops(),\n        None,\n    )\n    .expect(\"slice_model\");\n\n    let meta_path = output_dir.join(\"metadata.msgpack\");\n    let deserialized = ModelMetadata::load(&meta_path).expect(\"load metadata\");\n\n    assert_eq!(original.slices.len(), deserialized.slices.len());\n    assert_eq!(original.original_model, deserialized.original_model);\n    assert_eq!(original.input_shape, deserialized.input_shape);\n    assert_eq!(original.output_shapes, deserialized.output_shapes);\n    assert_eq!(original.traced_shapes, deserialized.traced_shapes);\n}\n\n#[test]\nfn materialize_from_manifest() {\n    let model_path = test_models_dir().join(\"net/model.onnx\");\n    assert!(\n        model_path.exists(),\n        \"test model not found at {}\",\n        model_path.display()\n    );\n\n    let tmp = tempfile::tempdir().expect(\"create temp dir\");\n    let output_dir = tmp.path().join(\"slices\");\n\n    let metadata = dsperse::slicer::slice_model(\n        &model_path,\n        Some(&output_dir),\n        None,\n        jstprove_circuits::ProofSystem::Expander.supported_ops(),\n        None,\n    )\n    .expect(\"slice_model\");\n\n    dsperse::slicer::materializer::ensure_all_slices_materialized(&output_dir, &metadata)\n        .expect(\"materialize all slices\");\n\n    for slice in &metadata.slices {\n        let slice_dir = output_dir.join(format!(\"slice_{}\", slice.index));\n        assert!(\n            slice_dir.exists(),\n            \"slice dir must exist after materialization: {}\",\n            slice_dir.display()\n        );\n\n        let payload_dir = 
slice_dir.join(\"payload\");\n        assert!(payload_dir.exists(), \"payload dir must exist\");\n\n        let onnx_file = payload_dir.join(&slice.filename);\n        assert!(\n            onnx_file.exists(),\n            \"onnx file must exist: {}\",\n            onnx_file.display()\n        );\n    }\n}\n\n#[test]\nfn resolve_onnx_points_to_existing_file_after_materialize() {\n    let model_path = test_models_dir().join(\"net/model.onnx\");\n    assert!(\n        model_path.exists(),\n        \"test model not found at {}\",\n        model_path.display()\n    );\n\n    let tmp = tempfile::tempdir().expect(\"create temp dir\");\n    let output_dir = tmp.path().join(\"slices\");\n\n    let metadata = dsperse::slicer::slice_model(\n        &model_path,\n        Some(&output_dir),\n        None,\n        jstprove_circuits::ProofSystem::Expander.supported_ops(),\n        None,\n    )\n    .expect(\"slice_model\");\n\n    dsperse::slicer::materializer::ensure_all_slices_materialized(&output_dir, &metadata)\n        .expect(\"materialize all slices\");\n\n    let loaded = ModelMetadata::load(&output_dir.join(\"metadata.msgpack\")).expect(\"load metadata\");\n    assert!(!loaded.slices.is_empty());\n\n    for slice in &loaded.slices {\n        let resolved = slice.resolve_onnx(&output_dir).unwrap();\n        assert!(\n            resolved.is_file(),\n            \"resolve_onnx for slice {} must point to a regular file, got: {}\",\n            slice.index,\n            resolved.display()\n        );\n        assert!(\n            resolved.starts_with(&output_dir),\n            \"resolved path must start with output_dir\"\n        );\n        let resolved_path = resolved.to_string_lossy();\n        let output_str = output_dir.to_string_lossy();\n        let count = resolved_path.matches(output_str.as_ref()).count();\n        assert_eq!(\n            count, 1,\n            \"output_dir must appear exactly once in resolved path, got {count}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/dsperse/tests/schema_roundtrip.rs",
    "content": "use std::path::Path;\n\nuse dsperse::schema::*;\n\n#[test]\nfn model_metadata_roundtrip() {\n    let json = r#\"{\n        \"original_model\": \"model.onnx\",\n        \"model_type\": \"onnx\",\n        \"input_shape\": [[1, 3, 32, 32]],\n        \"output_shapes\": [[1, 10]],\n        \"slice_points\": [2, 5],\n        \"slices\": [\n            {\n                \"index\": 0,\n                \"filename\": \"slice_0.onnx\",\n                \"path\": \"/tmp/slices/slice_0/payload/slice_0.onnx\",\n                \"relative_path\": \"slice_0/payload/slice_0.onnx\",\n                \"shape\": {\n                    \"tensor_shape\": {\n                        \"input\": [[1, 3, 32, 32]],\n                        \"output\": [[1, 16, 16, 16]]\n                    }\n                },\n                \"dependencies\": {\n                    \"input\": [\"input\"],\n                    \"output\": [\"conv1_out\"],\n                    \"filtered_inputs\": [\"input\"]\n                },\n                \"compilation\": {\n                    \"jstprove\": {\n                        \"compiled\": true,\n                        \"tiled\": false,\n                        \"weights_as_inputs\": false,\n                        \"files\": {\n                            \"compiled\": \"jstprove/circuit.txt\",\n                            \"settings\": \"jstprove/settings.json\"\n                        }\n                    }\n                }\n            },\n            {\n                \"index\": 1,\n                \"filename\": \"slice_1.onnx\",\n                \"path\": \"/tmp/slices/slice_1/payload/slice_1.onnx\",\n                \"relative_path\": \"slice_1/payload/slice_1.onnx\",\n                \"shape\": {\n                    \"tensor_shape\": {\n                        \"input\": [[1, 16, 16, 16]],\n                        \"output\": [[1, 10]]\n                    }\n                },\n                \"dependencies\": {\n            
        \"input\": [\"conv1_out\"],\n                    \"output\": [\"output\"],\n                    \"filtered_inputs\": [\"conv1_out\"]\n                },\n                \"tiling\": {\n                    \"slice_idx\": 1,\n                    \"tile_size\": 8,\n                    \"num_tiles\": 4,\n                    \"tiles_y\": 2,\n                    \"tiles_x\": 2,\n                    \"halo\": [1, 1],\n                    \"out_tile\": [8, 8],\n                    \"stride\": [1, 1],\n                    \"c_in\": 16,\n                    \"c_out\": 32,\n                    \"input_name\": \"conv1_out\",\n                    \"output_name\": \"conv2_out\",\n                    \"tile\": {\n                        \"path\": \"tiles/tile.onnx\",\n                        \"conv_out\": [8, 8]\n                    },\n                    \"tiles\": [\n                        {\"path\": \"tiles/tile.onnx\", \"conv_out\": [8, 8]},\n                        {\"path\": \"tiles/tile.onnx\", \"conv_out\": [8, 8]}\n                    ]\n                },\n                \"compilation\": {\n                    \"jstprove\": {\n                        \"compiled\": false,\n                        \"tiled\": false,\n                        \"weights_as_inputs\": false,\n                        \"files\": {}\n                    }\n                }\n            }\n        ]\n    }\"#;\n\n    let meta: ModelMetadata = serde_json::from_str(json).unwrap();\n    assert_eq!(meta.original_model, \"model.onnx\");\n    assert_eq!(meta.slices.len(), 2);\n    assert_eq!(meta.slice_points, vec![2, 5]);\n\n    let s0 = &meta.slices[0];\n    assert_eq!(s0.index, 0);\n    assert!(s0.compilation.jstprove.compiled);\n    assert_eq!(\n        s0.compilation.jstprove.files.compiled.as_deref(),\n        Some(\"jstprove/circuit.txt\")\n    );\n    assert!(s0.tiling.is_none());\n\n    let s1 = &meta.slices[1];\n    assert!(s1.tiling.is_some());\n    let tiling = 
s1.tiling.as_ref().unwrap();\n    assert_eq!(tiling.num_tiles, 4);\n    assert_eq!(tiling.halo, [1, 1, 1, 1]);\n    assert_eq!(tiling.tiles.as_ref().unwrap().len(), 2);\n\n    let msgpack_bytes = rmp_serde::to_vec_named(&meta).unwrap();\n    let meta2: ModelMetadata = rmp_serde::from_slice(&msgpack_bytes).unwrap();\n    assert_eq!(meta2.slices.len(), 2);\n    assert_eq!(meta2.slices[0].index, 0);\n}\n\n#[test]\nfn run_metadata_roundtrip() {\n    let json = r#\"{\n        \"slices\": {\n            \"slice_0\": {\n                \"path\": \"slice_0/payload/slice_0.onnx\",\n                \"input_shape\": [[1, 3, 32, 32]],\n                \"output_shape\": [[1, 16, 16, 16]],\n                \"dependencies\": {\n                    \"input\": [\"input\"],\n                    \"output\": [\"conv1_out\"],\n                    \"filtered_inputs\": [\"input\"]\n                },\n                \"backend\": \"jstprove\",\n                \"circuit_path\": \"slice_0/payload/jstprove/circuit.txt\"\n            }\n        },\n        \"execution_chain\": {\n            \"head\": \"slice_0\",\n            \"nodes\": {\n                \"slice_0\": {\n                    \"slice_id\": \"slice_0\",\n                    \"primary\": \"slice_0/payload/jstprove/circuit.txt\",\n                    \"fallbacks\": [\"slice_0/payload/slice_0.onnx\"],\n                    \"use_circuit\": true,\n                    \"next\": null,\n                    \"circuit_path\": \"slice_0/payload/jstprove/circuit.txt\",\n                    \"onnx_path\": \"slice_0/payload/slice_0.onnx\",\n                    \"backend\": \"jstprove\"\n                }\n            },\n            \"fallback_map\": {},\n            \"execution_results\": [],\n            \"jstprove_proved_slices\": 0,\n            \"jstprove_verified_slices\": 0\n        }\n    }\"#;\n\n    let meta: RunMetadata = serde_json::from_str(json).unwrap();\n    assert_eq!(meta.slices.len(), 1);\n\n    let slice = 
meta.get_slice(\"slice_0\").unwrap();\n    assert_eq!(slice.backend, BackendKind::Jstprove);\n    assert_eq!(\n        slice.jstprove_circuit_path.as_deref(),\n        Some(\"slice_0/payload/jstprove/circuit.txt\")\n    );\n\n    let chain = &meta.execution_chain;\n    assert_eq!(chain.head.as_deref(), Some(\"slice_0\"));\n    assert!(chain.nodes[\"slice_0\"].use_circuit);\n\n    let circuit_slices: Vec<_> = meta.iter_circuit_slices().collect();\n    assert_eq!(circuit_slices.len(), 1);\n    assert_eq!(circuit_slices[0].0, \"slice_0\");\n\n    let msgpack_bytes = rmp_serde::to_vec_named(&meta).unwrap();\n    let meta2: RunMetadata = rmp_serde::from_slice(&msgpack_bytes).unwrap();\n    assert_eq!(meta2.slices.len(), 1);\n}\n\n#[test]\nfn execution_info_with_tiles() {\n    let json = r#\"{\n        \"method\": \"tiled\",\n        \"success\": true,\n        \"tile_exec_infos\": [\n            {\"tile_idx\": 0, \"success\": true, \"method\": \"jstprove_gen_witness\", \"time_sec\": 1.5},\n            {\"tile_idx\": 1, \"success\": true, \"method\": \"jstprove_gen_witness\", \"time_sec\": 1.3},\n            {\"tile_idx\": 2, \"success\": false, \"error\": \"timeout\", \"time_sec\": 30.0}\n        ]\n    }\"#;\n\n    let info: ExecutionInfo = serde_json::from_str(json).unwrap();\n    assert!(info.success);\n    assert_eq!(info.tile_exec_infos.len(), 3);\n    assert!(!info.tile_exec_infos[2].success);\n    assert_eq!(info.tile_exec_infos[2].error.as_deref(), Some(\"timeout\"));\n}\n\n#[test]\nfn channel_split_roundtrip() {\n    let json = r#\"{\n        \"slice_idx\": 2,\n        \"c_in\": 64,\n        \"c_out\": 128,\n        \"num_groups\": 4,\n        \"channels_per_group\": 16,\n        \"input_name\": \"relu1_out\",\n        \"output_name\": \"conv2_out\",\n        \"h\": 16,\n        \"w\": 16,\n        \"groups\": [\n            {\"group_idx\": 0, \"c_start\": 0, \"c_end\": 16, \"path\": \"channel_groups/group_0.onnx\"},\n            {\"group_idx\": 1, \"c_start\": 
16, \"c_end\": 32, \"path\": \"channel_groups/group_1.onnx\"}\n        ],\n        \"bias_path\": \"channel_groups/bias.msgpack\"\n    }\"#;\n\n    let info: ChannelSplitInfo = serde_json::from_str(json).unwrap();\n    assert_eq!(info.num_groups, 4);\n    assert_eq!(info.groups.len(), 2);\n    assert_eq!(info.groups[0].c_end, 16);\n    assert_eq!(\n        info.bias_path.as_deref(),\n        Some(\"channel_groups/bias.msgpack\")\n    );\n\n    let msgpack_bytes = rmp_serde::to_vec_named(&info).unwrap();\n    let info2: ChannelSplitInfo = rmp_serde::from_slice(&msgpack_bytes).unwrap();\n    assert_eq!(info2.num_groups, 4);\n}\n\n#[test]\nfn compilation_files_aliases() {\n    let json1 = r#\"{\"compiled\": \"circuit.txt\"}\"#;\n    let json2 = r#\"{\"compiled_circuit\": \"circuit.txt\"}\"#;\n    let json3 = r#\"{\"circuit\": \"circuit.txt\"}\"#;\n\n    let f1: CompilationFiles = serde_json::from_str(json1).unwrap();\n    let f2: CompilationFiles = serde_json::from_str(json2).unwrap();\n    let f3: CompilationFiles = serde_json::from_str(json3).unwrap();\n\n    assert_eq!(f1.compiled.as_deref(), Some(\"circuit.txt\"));\n    assert_eq!(f2.compiled.as_deref(), Some(\"circuit.txt\"));\n    assert_eq!(f3.compiled.as_deref(), Some(\"circuit.txt\"));\n}\n\n#[test]\nfn backend_serde() {\n    assert_eq!(\n        serde_json::to_string(&BackendKind::Jstprove).unwrap(),\n        r#\"\"jstprove\"\"#\n    );\n    assert_eq!(\n        serde_json::to_string(&BackendKind::Onnx).unwrap(),\n        r#\"\"onnx\"\"#\n    );\n\n    let b: BackendKind = serde_json::from_str(r#\"\"jstprove\"\"#).unwrap();\n    assert_eq!(b, BackendKind::Jstprove);\n\n    let b: BackendKind = serde_json::from_str(r#\"\"JSTPROVE\"\"#).unwrap();\n    assert_eq!(b, BackendKind::Jstprove);\n}\n\n#[test]\nfn tensor_shape_i64_deserialization() {\n    let json = r#\"{\n        \"input\": [[1, 3, 224, 224]],\n        \"output\": [[1, 1000]]\n    }\"#;\n\n    let shape: TensorShape = 
serde_json::from_str(json).unwrap();\n    assert_eq!(shape.input, vec![vec![1i64, 3, 224, 224]]);\n    assert_eq!(shape.output, vec![vec![1i64, 1000]]);\n\n    let msgpack_bytes = rmp_serde::to_vec_named(&shape).unwrap();\n    let shape2: TensorShape = rmp_serde::from_slice(&msgpack_bytes).unwrap();\n    assert_eq!(shape2.input, shape.input);\n    assert_eq!(shape2.output, shape.output);\n}\n\n#[test]\nfn tensor_shape_rejects_non_integer() {\n    let json = r#\"{\"input\": [[1, \"hello\", 3]], \"output\": []}\"#;\n    let result: std::result::Result<TensorShape, _> = serde_json::from_str(json);\n    assert!(result.is_err());\n}\n\n#[test]\nfn run_slice_metadata_i64_shapes() {\n    let json = r#\"{\n        \"path\": \"slice_0/payload/slice_0.onnx\",\n        \"input_shape\": [[1, 3, 32, 32]],\n        \"output_shape\": [[1, 16, 16, 16]],\n        \"dependencies\": {\n            \"input\": [\"input\"],\n            \"output\": [\"conv1_out\"],\n            \"filtered_inputs\": [\"input\"]\n        },\n        \"backend\": \"onnx\"\n    }\"#;\n\n    let meta: RunSliceMetadata = serde_json::from_str(json).unwrap();\n    assert_eq!(meta.input_shape, vec![vec![1i64, 3, 32, 32]]);\n    assert_eq!(meta.output_shape, vec![vec![1i64, 16, 16, 16]]);\n\n    let msgpack_bytes = rmp_serde::to_vec_named(&meta).unwrap();\n    let meta2: RunSliceMetadata = rmp_serde::from_slice(&msgpack_bytes).unwrap();\n    assert_eq!(meta2.input_shape, meta.input_shape);\n    assert_eq!(meta2.output_shape, meta.output_shape);\n}\n\n#[test]\nfn resolve_onnx_uses_relative_path_not_absolute() {\n    let json = r#\"{\n        \"index\": 0,\n        \"filename\": \"slice_0.onnx\",\n        \"path\": \"/original/cwd/slices/slice_0/payload/slice_0.onnx\",\n        \"relative_path\": \"slice_0/payload/slice_0.onnx\",\n        \"shape\": {\"tensor_shape\": {\"input\": [[1, 3, 32, 32]], \"output\": [[1, 10]]}},\n        \"dependencies\": {\"input\": [], \"output\": [], \"filtered_inputs\": []},\n        
\"compilation\": {\"jstprove\": {\"compiled\": false, \"tiled\": false, \"weights_as_inputs\": false, \"files\": {}}}\n    }\"#;\n\n    let slice: SliceMetadata = serde_json::from_str(json).unwrap();\n    let slices_dir = Path::new(\"/relocated/slices\");\n    let resolved = slice.resolve_onnx(slices_dir).unwrap();\n\n    assert_eq!(\n        resolved,\n        Path::new(\"/relocated/slices/slice_0/payload/slice_0.onnx\"),\n        \"resolve_onnx must use relative_path (relative to slices_dir), not the absolute path field\"\n    );\n\n    assert!(\n        !resolved.to_string_lossy().contains(\"/original/\"),\n        \"resolved path must not contain the original CWD-relative path\"\n    );\n}\n"
  },
  {
    "path": "crates/dsperse/tests/sn2_contract.rs",
    "content": "use std::path::Path;\n\nuse ndarray::{ArrayD, IxDyn};\nuse rmpv::Value;\n\nfn make_value_array(vals: &[f64]) -> Value {\n    Value::Array(vals.iter().map(|&v| Value::F64(v)).collect())\n}\n\nfn make_value_2d(rows: &[&[f64]]) -> Value {\n    Value::Array(rows.iter().map(|row| make_value_array(row)).collect())\n}\n\nfn make_value_3d(planes: &[&[&[f64]]]) -> Value {\n    Value::Array(planes.iter().map(|plane| make_value_2d(plane)).collect())\n}\n\nfn make_value_4d(blocks: &[&[&[&[f64]]]]) -> Value {\n    Value::Array(blocks.iter().map(|block| make_value_3d(block)).collect())\n}\n\n#[test]\nfn value_arrayd_roundtrip_1d() {\n    let input = make_value_array(&[1.0, 2.0, 3.0, 4.0]);\n    let arr = dsperse::utils::io::value_to_arrayd(&input).unwrap();\n    assert_eq!(arr.shape(), &[4]);\n    assert_eq!(arr[IxDyn(&[0])], 1.0);\n    assert_eq!(arr[IxDyn(&[3])], 4.0);\n\n    let output = dsperse::utils::io::arrayd_to_value(&arr);\n    assert_eq!(output, input);\n}\n\n#[test]\nfn value_arrayd_roundtrip_2d() {\n    let input = make_value_2d(&[&[1.0, 2.0], &[3.0, 4.0]]);\n    let arr = dsperse::utils::io::value_to_arrayd(&input).unwrap();\n    assert_eq!(arr.shape(), &[2, 2]);\n    assert_eq!(arr[IxDyn(&[0, 0])], 1.0);\n    assert_eq!(arr[IxDyn(&[1, 1])], 4.0);\n\n    let output = dsperse::utils::io::arrayd_to_value(&arr);\n    assert_eq!(output, input);\n}\n\n#[test]\nfn value_arrayd_roundtrip_3d() {\n    let input = make_value_3d(&[&[&[1.0, 2.0], &[3.0, 4.0]], &[&[5.0, 6.0], &[7.0, 8.0]]]);\n    let arr = dsperse::utils::io::value_to_arrayd(&input).unwrap();\n    assert_eq!(arr.shape(), &[2, 2, 2]);\n    assert_eq!(arr[IxDyn(&[0, 0, 0])], 1.0);\n    assert_eq!(arr[IxDyn(&[1, 1, 1])], 8.0);\n\n    let output = dsperse::utils::io::arrayd_to_value(&arr);\n    assert_eq!(output, input);\n}\n\n#[test]\nfn value_arrayd_roundtrip_4d() {\n    let input = make_value_4d(&[&[&[&[0.5, 1.5], &[2.5, 3.5]], &[&[4.5, 5.5], &[6.5, 7.5]]]]);\n    let arr = 
dsperse::utils::io::value_to_arrayd(&input).unwrap();\n    assert_eq!(arr.shape(), &[1, 2, 2, 2]);\n    assert_eq!(arr[IxDyn(&[0, 0, 0, 0])], 0.5);\n    assert_eq!(arr[IxDyn(&[0, 1, 1, 1])], 7.5);\n\n    let output = dsperse::utils::io::arrayd_to_value(&arr);\n    assert_eq!(output, input);\n}\n\n#[test]\nfn value_arrayd_full_roundtrip_preserves_values() {\n    let original = make_value_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);\n    let arr = dsperse::utils::io::value_to_arrayd(&original).unwrap();\n    let reconstructed = dsperse::utils::io::arrayd_to_value(&arr);\n    let arr2 = dsperse::utils::io::value_to_arrayd(&reconstructed).unwrap();\n\n    assert_eq!(arr.shape(), arr2.shape());\n    assert_eq!(arr, arr2);\n    assert_eq!(original, reconstructed);\n}\n\n#[test]\nfn extract_input_data_key_precedence() {\n    let val = Value::Map(vec![\n        (Value::String(\"input_data\".into()), make_value_array(&[1.0])),\n        (Value::String(\"input\".into()), make_value_array(&[2.0])),\n        (Value::String(\"data\".into()), make_value_array(&[3.0])),\n        (Value::String(\"inputs\".into()), make_value_array(&[4.0])),\n    ]);\n    let extracted = dsperse::utils::io::extract_input_data(&val).unwrap();\n    assert_eq!(extracted, &make_value_array(&[1.0]));\n}\n\n#[test]\nfn extract_input_data_fallback_to_input() {\n    let val = Value::Map(vec![\n        (Value::String(\"input\".into()), make_value_array(&[2.0])),\n        (Value::String(\"data\".into()), make_value_array(&[3.0])),\n        (Value::String(\"inputs\".into()), make_value_array(&[4.0])),\n    ]);\n    let extracted = dsperse::utils::io::extract_input_data(&val).unwrap();\n    assert_eq!(extracted, &make_value_array(&[2.0]));\n}\n\n#[test]\nfn extract_input_data_fallback_to_data() {\n    let val = Value::Map(vec![\n        (Value::String(\"data\".into()), make_value_array(&[3.0])),\n        (Value::String(\"inputs\".into()), make_value_array(&[4.0])),\n    ]);\n    let extracted = 
dsperse::utils::io::extract_input_data(&val).unwrap();\n    assert_eq!(extracted, &make_value_array(&[3.0]));\n}\n\n#[test]\nfn extract_input_data_fallback_to_inputs() {\n    let val = Value::Map(vec![(\n        Value::String(\"inputs\".into()),\n        make_value_array(&[4.0]),\n    )]);\n    let extracted = dsperse::utils::io::extract_input_data(&val).unwrap();\n    assert_eq!(extracted, &make_value_array(&[4.0]));\n}\n\n#[test]\nfn extract_input_data_returns_none_for_unrecognized_keys() {\n    let val = Value::Map(vec![\n        (Value::String(\"tensor\".into()), make_value_array(&[1.0])),\n        (Value::String(\"x\".into()), make_value_array(&[2.0])),\n    ]);\n    assert!(dsperse::utils::io::extract_input_data(&val).is_none());\n}\n\n#[test]\nfn slice_dir_path_formats_correctly() {\n    let root = Path::new(\"/some/root\");\n    assert_eq!(\n        dsperse::utils::paths::slice_dir_path(root, 0),\n        Path::new(\"/some/root/slice_0\")\n    );\n    assert_eq!(\n        dsperse::utils::paths::slice_dir_path(root, 5),\n        Path::new(\"/some/root/slice_5\")\n    );\n    assert_eq!(\n        dsperse::utils::paths::slice_dir_path(root, 42),\n        Path::new(\"/some/root/slice_42\")\n    );\n}\n\n#[test]\nfn arrayd_to_value_then_extract_input_data_integration() {\n    let arr = ArrayD::from_shape_vec(IxDyn(&[1, 3]), vec![1.0, 2.0, 3.0]).unwrap();\n    let tensor_val = dsperse::utils::io::arrayd_to_value(&arr);\n    let wrapped = Value::Map(vec![(Value::String(\"input_data\".into()), tensor_val)]);\n\n    let extracted = dsperse::utils::io::extract_input_data(&wrapped).unwrap();\n    let roundtripped = dsperse::utils::io::value_to_arrayd(extracted).unwrap();\n    assert_eq!(arr.shape(), roundtripped.shape());\n    assert_eq!(arr, roundtripped);\n}\n"
  },
  {
    "path": "deny.toml",
    "content": "[graph]\ntargets = []\nall-features = false\n\n[advisories]\nyanked = \"warn\"\n\n[bans]\nmultiple-versions = \"warn\"\nwildcards = \"warn\"\n\n[sources]\nunknown-registry = \"deny\"\nunknown-git = \"warn\"\nallow-git = [\n    \"https://github.com/inference-labs-inc/JSTprove.git\",\n]\n"
  },
  {
    "path": "docs/JSTPROVE_BACKEND.md",
    "content": "# JSTprove Backend Integration\n\n## Overview\n\nDSperse uses [JSTprove](https://github.com/inference-labs-inc/JSTprove) as its ZK proving backend. JSTprove is integrated as a Rust library dependency (`jstprove_circuits` crate) linked at compile time — there is no external CLI or Python process involved.\n\nDSperse is proving-system-agnostic. JSTprove currently provides two proof system backends selectable via the `--proof-system` flag:\n\n| Proof System | Description |\n|--------------|-------------|\n| `expander` (default) | Expander-based proving system |\n| `remainder` | Remainder-based proving system |\n\n## Architecture\n\nThe integration lives in `crates/dsperse/src/backend/jstprove.rs`, which wraps the `jstprove_circuits` crate. The `JstproveBackend` struct exposes the following operations that map directly to the pipeline stages:\n\n| Pipeline Stage | JSTprove Function | Description |\n|----------------|-------------------|-------------|\n| Compile | `compile_bn254` | Compiles an ONNX slice into a BN254 circuit (msgpack bundle) |\n| Witness | `witness_bn254` / `witness_bn254_from_f64` | Generates a witness from JSON or raw f64 inputs |\n| Prove | `prove_bn254` | Generates a proof from a compiled circuit and witness |\n| Verify | `verify_bn254` | Verifies a proof against a circuit and witness |\n| Extract | `extract_outputs_bn254` | Extracts model outputs from a witness |\n\nCircuit compilation produces a msgpack bundle containing the circuit, witness solver, and optional metadata (`CircuitParams`). 
All subsequent operations load this bundle via `read_circuit_msgpack`.\n\n## Proof Pipeline Flow\n\n```text\nONNX slice\n    |\n    v\ncompile_bn254 --> compiled circuit bundle (.msgpack)\n    |\n    v\nwitness_bn254 --> witness bytes\n    |\n    v\nprove_bn254 --> proof bytes\n    |\n    v\nverify_bn254 --> bool (valid/invalid)\n```\n\n## Proof System Selection\n\nThe `--proof-system` flag is available on `slice`, `compile`, and `full-run` subcommands. Each proof system defines its own set of supported ONNX operations, queryable via `ProofSystem::supported_ops()`. The `--circuit-ops` flag allows restricting compilation to a subset of supported ops.\n\n```bash\ndsperse slice --model-dir models/net --proof-system expander\ndsperse compile --model-dir models/net --proof-system remainder\ndsperse full-run --model-dir models/net --proof-system expander --circuit-ops \"MatMul,Relu\"\n```\n\n## Dependency\n\nJSTprove is pulled in as a Cargo git dependency via the `jstprove_circuits` crate. No separate installation step is required — it is compiled into the dsperse binary.\n"
  },
  {
    "path": "docs/overview.md",
    "content": "# DSperse: Distributed zkML\n\n## Overview\n\nDSperse is a proving-system-agnostic intelligent slicer for verifiable AI. It decomposes ONNX neural network models into circuit-compatible segments and orchestrates compilation, inference, proving, and verification across pluggable ZK backends.\n\n### Core Purpose\nThe project solves a significant challenge in zkML (zero-knowledge machine learning) by introducing intelligent model slicing that enables distributed proof computation across heterogeneous hardware.\n\n### Key Technical Innovation\nThe main innovation is the concept of \"model slicing\":\n1. The entire neural network is not processed at once\n2. Instead, the system splits the neural network into manageable segments\n3. Each segment can be processed independently for analysis, inference, or proof generation\n\n### Primary Goals\n1. **Model Slicing**\n    - Split neural network models into individual layers or custom segments\n    - Support ONNX models\n    - Enable detailed analysis of model components\n\n2. **Distributed Computation**\n    - Break down large ML models into manageable pieces\n    - Enable parallel processing across multiple machines\n    - Support both GPU and non-GPU nodes\n\n3. **Resource Optimization**\n    - Reduce RAM requirements through model splitting\n    - Implement efficient inference pipelines\n    - Better manage compute resources\n\n4. **System Flexibility**\n    - Support for different model types\n    - Configurable slicing strategies\n    - Adaptable to different hardware capabilities\n\n5. **Zero-Knowledge Proofs**\n    - Generate proofs for sliced model execution via JSTprove integration\n    - Proving-system-agnostic design supporting Expander and Remainder backends\n    - Optimize proof generation for distributed environments\n\n### Implementation Framework\n- Built on top of existing tools:\n    - ONNX for model representation and interoperability\n    - JSTprove (`jstprove_circuits` Rust crate) for zero-knowledge proof generation\n    - Expander and Remainder as the underlying proving systems\n\n- Comprehensive CLI for:\n    - Model slicing\n    - Inference\n    - Proof generation\n    - Proof verification\n\n- Designed to work with various neural network architectures\n- Focuses on practical applications of zkML technology\n"
  },
  {
    "path": "docs/uv_packaging.md",
    "content": "# Developer Guide\n\nThis document provides a guide for developers who contribute to the project.\n\n## Build System\n\nThe project uses [maturin](https://www.maturin.rs/) as its build backend. The native Rust extension is compiled via PyO3 and exposed as `dsperse._native`. There are no Python-level dependencies beyond the compiled extension itself.\n\nThe build configuration in `pyproject.toml`:\n\n```toml\n[build-system]\nrequires = [\"maturin>=1.0,<2.0\"]\nbuild-backend = \"maturin\"\n\n[tool.maturin]\nfeatures = [\"python\"]\nmodule-name = \"dsperse._native\"\npython-source = \"python\"\nmanifest-path = \"crates/dsperse/Cargo.toml\"\n```\n\n## Local Development\n\nCreate a virtual environment and build the extension in development mode:\n\n```sh\nuv venv\nsource .venv/bin/activate\nmaturin develop --features python\n```\n\nThis compiles the Rust crate and installs the resulting native extension into the active virtualenv. Re-run `maturin develop` after any Rust code changes.\n\n## Building a Wheel\n\n```sh\nmaturin build --release --features python\n```\n\nThe output wheel is self-contained with no additional Python dependencies.\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"maturin>=1.0,<2.0\"]\nbuild-backend = \"maturin\"\n\n[project]\nname = \"dsperse\"\nversion = \"0.0.0\"\ndescription = \"Distributed zkML Toolkit\"\nreadme = \"README.md\"\nrequires-python = \">=3.9\"\nlicense = { file = \"LICENSE\" }\nauthors = [{ name = \"Inference Labs\", email = \"info@inferencelabs.com\" }]\n\n[project.scripts]\ndsperse = \"dsperse.cli:main\"\n\n[tool.maturin]\nfeatures = [\"python\"]\nmodule-name = \"dsperse._native\"\npython-source = \"python\"\nmanifest-path = \"crates/dsperse/Cargo.toml\"\n"
  },
  {
    "path": "python/dsperse/__init__.py",
    "content": "from dsperse._native import (\n    slice_model,\n    compile_slices,\n    run_inference,\n    prove_run,\n    verify_run,\n    setup_holographic,\n)\n\n__all__ = [\n    \"slice_model\",\n    \"compile_slices\",\n    \"run_inference\",\n    \"prove_run\",\n    \"verify_run\",\n    \"setup_holographic\",\n]\n"
  },
  {
    "path": "python/dsperse/cli.py",
    "content": "import sys\n\n\ndef main():\n    try:\n        from dsperse._native import cli_main\n    except ImportError:\n        print(\"dsperse native extension not found; install with: pip install dsperse\", file=sys.stderr)\n        return 1\n\n    try:\n        cli_main()\n    except SystemExit:\n        raise\n    except Exception as e:  # noqa: BLE001 - top-level CLI wrapper to convert any error to exit code 1\n        print(f\"error: {e}\", file=sys.stderr)\n        return 1\n    return 0\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main())\n"
  },
  {
    "path": "rust-toolchain.toml",
    "content": "[toolchain]\nchannel = \"nightly-2026-02-22\"\ncomponents = [\"clippy\", \"rustfmt\"]\n"
  }
]